Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_inequality_join.py
6939 views
1
from __future__ import annotations
2
3
from datetime import datetime
4
from typing import TYPE_CHECKING, Any
5
6
import hypothesis.strategies as st
7
import numpy as np
8
import pytest
9
from hypothesis import given
10
11
import polars as pl
12
from polars.testing import assert_frame_equal
13
from polars.testing.parametric.strategies import series
14
15
if TYPE_CHECKING:
16
from hypothesis.strategies import DrawFn, SearchStrategy
17
18
19
@pytest.mark.parametrize(
    ("pred_1", "pred_2"),
    [
        (pl.col("time") > pl.col("time_right"), pl.col("cost") < pl.col("cost_right")),
        (pl.col("time_right") < pl.col("time"), pl.col("cost_right") > pl.col("cost")),
    ],
)
def test_self_join(pred_1: pl.Expr, pred_2: pl.Expr) -> None:
    """Join a frame with itself on two inequality predicates.

    Both parametrizations state the same two conditions with the operands
    swapped, so both must produce the same two matching pairs.
    """
    columns = ["t_id", "time", "cost", "cores"]
    west = pl.DataFrame(
        [
            (404, 100, 6, 4),
            (498, 140, 11, 2),
            (676, 80, 10, 1),
            (742, 90, 5, 4),
        ],
        schema=columns,
        orient="row",
    )

    # On a self-join every right-hand column receives the "_right" suffix.
    # Rows 742 and 404 are each later and cheaper than row 676 only.
    expected = pl.DataFrame(
        [
            (742, 90, 5, 4, 676, 80, 10, 1),
            (404, 100, 6, 4, 676, 80, 10, 1),
        ],
        schema=columns + [f"{name}_right" for name in columns],
        orient="row",
    )

    result = west.join_where(west, pred_1, pred_2)
    assert_frame_equal(result, expected, check_row_order=False, check_exact=True)
51
52
53
def test_basic_ie_join() -> None:
    """Join two frames on ``dur < time`` and ``rev > cost``."""
    east = pl.DataFrame(
        [(100, 140, 12, 2), (101, 100, 12, 8), (102, 90, 5, 4)],
        schema=["id", "dur", "rev", "cores"],
        orient="row",
    )
    west = pl.DataFrame(
        [(404, 100, 6, 4), (498, 140, 11, 2), (676, 80, 10, 1), (742, 90, 5, 4)],
        schema=["t_id", "time", "cost", "cores"],
        orient="row",
    )

    # Only east row 101 (dur=100, rev=12) qualifies, and only against west
    # row 498 (time=140, cost=11). The clashing `cores` column is suffixed.
    expected = pl.DataFrame(
        [(101, 100, 12, 8, 498, 140, 11, 2)],
        schema=["id", "dur", "rev", "cores", "t_id", "time", "cost", "cores_right"],
        orient="row",
    )

    result = east.join_where(
        west,
        pl.col("dur") < pl.col("time"),
        pl.col("rev") > pl.col("cost"),
    )
    assert_frame_equal(result, expected, check_row_order=False, check_exact=True)
90
91
92
@given(
    offset=st.integers(-6, 5),
    length=st.integers(0, 6),
)
def test_ie_join_with_slice(offset: int, length: int) -> None:
    """Slicing a lazy two-inequality join yields a valid slice of the full join.

    The join's output order is unspecified, so the slice is checked by length
    and by content. Every returned row must appear in the full result, and
    because all rows of the full result are distinct, no row may appear twice.
    """
    east = pl.DataFrame(
        {
            "id": [100, 101, 102],
            "dur": [120, 140, 160],
            "rev": [12, 14, 16],
            "cores": [2, 8, 4],
        }
    ).lazy()
    west = pl.DataFrame(
        {
            "t_id": [404, 498, 676, 742],
            "time": [90, 130, 150, 170],
            "cost": [9, 13, 15, 16],
            "cores": [4, 2, 1, 4],
        }
    ).lazy()

    actual = (
        east.join_where(
            west,
            pl.col("dur") < pl.col("time"),
            pl.col("rev") < pl.col("cost"),
        )
        .slice(offset, length)
        .collect()
    )

    expected_full = pl.DataFrame(
        {
            "id": [101, 101, 100, 100, 100],
            "dur": [140, 140, 120, 120, 120],
            "rev": [14, 14, 12, 12, 12],
            "cores": [8, 8, 2, 2, 2],
            "t_id": [676, 742, 498, 676, 742],
            "time": [150, 170, 130, 150, 170],
            "cost": [15, 16, 13, 15, 16],
            "cores_right": [1, 4, 2, 1, 4],
        }
    )
    # The ordering of the result is arbitrary, so we can only verify the
    # slice's size and that its rows come from the full expected result.
    assert len(actual) == len(expected_full.slice(offset, length))

    expected_rows = set(expected_full.iter_rows())
    actual_rows = actual.rows()
    # Membership alone would accept a slice that repeats a valid row. The
    # expected rows are all distinct, so a correct slice has no duplicates.
    assert len(set(actual_rows)) == len(actual_rows), f"duplicate rows: {actual_rows}"
    for row in actual_rows:
        assert row in expected_rows, f"{row} not in expected rows"
143
144
145
def test_ie_join_with_expressions() -> None:
    """Inequality predicates may compare computed expressions, not just columns."""
    east = pl.DataFrame(
        [(100, 70, 12, 2), (101, 50, 12, 8), (102, 45, 5, 4)],
        schema=["id", "dur", "rev", "cores"],
        orient="row",
    )
    west = pl.DataFrame(
        [(404, 100, 12, 4), (498, 140, 22, 2), (676, 80, 20, 1), (742, 90, 10, 4)],
        schema=["t_id", "time", "cost", "cores"],
        orient="row",
    )

    # Left side doubles `dur`; right side halves `cost` through an Int32
    # round trip. The output still carries the untransformed columns.
    doubled_dur_before_time = (pl.col("dur") * 2) < pl.col("time")
    rev_above_half_cost = pl.col("rev") > (pl.col("cost").cast(pl.Int32) // 2).cast(
        pl.Int64
    )

    expected = pl.DataFrame(
        [(101, 50, 12, 8, 498, 140, 22, 2)],
        schema=["id", "dur", "rev", "cores", "t_id", "time", "cost", "cores_right"],
        orient="row",
    )

    result = east.join_where(west, doubled_dur_before_time, rev_above_half_cost)
    assert_frame_equal(result, expected, check_row_order=False, check_exact=True)
182
183
184
@pytest.mark.parametrize(
    "range_constraint",
    [
        [
            # can write individual components
            pl.col("time") >= pl.col("start_time"),
            pl.col("time") < pl.col("end_time"),
        ],
        [
            # or a single `is_between` expression
            pl.col("time").is_between("start_time", "end_time", closed="left")
        ],
    ],
)
def test_join_where_predicates(range_constraint: list[pl.Expr]) -> None:
    """Check join planning and results for range predicates plus a group key.

    Four scenarios are covered:

    * the range predicates alone, checked against a hand-written result;
    * ``==`` on ``group`` plus the range, whose plan must show ``INNER JOIN``
      and ``FILTER``;
    * ``!=`` on ``group`` plus the range, whose plan must show ``IEJOIN`` and
      ``FILTER``;
    * ``!=`` on ``group`` alone, whose plan must show ``NESTED LOOP``.

    The second and third scenarios are compared with a cross-join-then-filter
    reference.
    """

    def at(minute: int, second: int = 0) -> datetime:
        # All timestamps fall within the same hour on the same day.
        return datetime(2024, 8, 26, 15, minute, second)

    left = pl.DataFrame(
        {
            "id": [0, 1, 2, 3, 4, 5],
            "group": [0, 0, 0, 1, 1, 1],
            "time": [at(34, 30), at(35, 30), at(36, 30), at(37, 30), at(38), at(39)],
        }
    )
    right = pl.DataFrame(
        {
            "id": [0, 1, 2],
            "group": [0, 1, 1],
            "start_time": [at(34), at(35), at(38)],
            "end_time": [at(36), at(37), at(39)],
        }
    )

    ranged = left.join_where(right, *range_constraint).select("id", "id_right")
    assert_frame_equal(
        ranged,
        pl.DataFrame({"id": [0, 1, 1, 2, 4], "id_right": [0, 0, 1, 1, 2]}),
        check_row_order=False,
        check_exact=True,
    )

    def lazy_join(*predicates: pl.Expr) -> pl.LazyFrame:
        # Lazy join_where of `left` against `right` with the given predicates.
        return left.lazy().join_where(right.lazy(), *predicates)

    def brute_force(group_predicate: pl.Expr) -> pl.DataFrame:
        # Reference result: cross join, then apply every predicate as a filter.
        return (
            left.join(right, how="cross")
            .filter(group_predicate, *range_constraint)
            .select("id", "id_right", "group")
            .sort("id")
        )

    # Equality on `group` combined with the range predicates.
    eq_query = (
        lazy_join(pl.col("group_right") == pl.col("group"), *range_constraint)
        .select("id", "id_right", "group")
        .sort("id")
    )
    eq_plan = eq_query.explain()
    assert "INNER JOIN" in eq_plan
    assert "FILTER" in eq_plan
    assert_frame_equal(
        eq_query.collect(),
        brute_force(pl.col("group") == pl.col("group_right")),
        check_exact=True,
    )

    # Inequality on `group` combined with the range predicates.
    ne_query = (
        lazy_join(pl.col("group") != pl.col("group_right"), *range_constraint)
        .select("id", "id_right", "group")
        .sort("id")
    )
    ne_plan = ne_query.explain()
    assert "IEJOIN" in ne_plan
    assert "FILTER" in ne_plan
    assert_frame_equal(
        ne_query.collect(),
        brute_force(pl.col("group") != pl.col("group_right")),
        check_exact=True,
    )

    # A lone `!=` predicate.
    ne_only = (
        lazy_join(pl.col("group") != pl.col("group_right"))
        .select("id", "group", "group_right")
        .sort("id")
        .select("group", "group_right")
    )
    assert "NESTED LOOP" in ne_only.explain()
    # Left group 0 (3 rows) x right group 1 (2 rows), then
    # left group 1 (3 rows) x right group 0 (1 row).
    assert ne_only.collect().to_dict(as_series=False) == {
        "group": [0, 0, 0, 0, 0, 0, 1, 1, 1],
        "group_right": [1, 1, 1, 1, 1, 1, 0, 0, 0],
    }
306
307
308
def _inequality_expression(col1: str, op: str, col2: str) -> pl.Expr:
    """Build the comparison expression ``pl.col(col1) <op> pl.col(col2)``.

    ``op`` must be one of ``"<"``, ``"<="``, ``">"`` or ``">="``.

    Raises
    ------
    ValueError
        If ``op`` is not one of the supported operator symbols.
    """
    builders = (
        ("<", lambda lhs, rhs: lhs < rhs),
        ("<=", lambda lhs, rhs: lhs <= rhs),
        (">", lambda lhs, rhs: lhs > rhs),
        (">=", lambda lhs, rhs: lhs >= rhs),
    )
    for symbol, build in builders:
        if op == symbol:
            return build(pl.col(col1), pl.col(col2))

    message = f"Invalid operator '{op}'"
    raise ValueError(message)
320
321
322
def operators() -> SearchStrategy[str]:
    """Return a strategy drawing one of the four inequality operator symbols."""
    return st.sampled_from(["<", "<=", ">", ">="])
325
326
327
@st.composite
def east_df(
    draw: DrawFn, with_nulls: bool = False, use_floats: bool = False
) -> pl.DataFrame:
    """Draw an "east" frame with columns ``id``, ``dur``, ``rev`` and ``cores``.

    The frame has between 0 and 20 rows. ``id`` is ``0..height-1`` and
    ``cores`` is drawn from ``1..10``.

    Parameters
    ----------
    draw
        Hypothesis draw function supplied by ``st.composite``.
    with_nulls
        Allow ``None`` in the ``dur`` and ``rev`` columns.
    use_floats
        Draw ``dur``/``rev`` as arbitrary floats (NaN allowed) stored as
        ``Float32``, instead of small integer ranges stored as ``Int64``.
    """
    height = draw(st.integers(min_value=0, max_value=20))

    def column(values: SearchStrategy[Any]) -> list[Any]:
        # Every column must contain exactly `height` entries.
        return draw(st.lists(values, min_size=height, max_size=height))

    value_dtype: type[pl.DataType]
    dur_values: SearchStrategy[Any]
    rev_values: SearchStrategy[Any]
    if not use_floats:
        value_dtype = pl.Int64
        # Narrow ranges; presumably chosen to overlap west_df's time/cost
        # ranges so that ties and matches are frequent.
        dur_values = st.integers(min_value=100, max_value=105)
        rev_values = st.integers(min_value=9, max_value=13)
    else:
        value_dtype = pl.Float32
        dur_values = st.floats(allow_nan=True)
        rev_values = st.floats(allow_nan=True)

    if with_nulls:
        dur_values = dur_values | st.none()
        rev_values = rev_values | st.none()

    # Draw order (dur, rev, cores) is kept stable for reproducible examples.
    dur = column(dur_values)
    rev = column(rev_values)
    cores = column(st.integers(min_value=1, max_value=10))

    return pl.DataFrame(
        [
            pl.Series("id", np.arange(0, height), dtype=pl.Int64),
            pl.Series("dur", dur, dtype=value_dtype),
            pl.Series("rev", rev, dtype=value_dtype),
            pl.Series("cores", cores, dtype=pl.Int64),
        ]
    )
363
364
365
@st.composite
def west_df(
    draw: DrawFn, with_nulls: bool = False, use_floats: bool = False
) -> pl.DataFrame:
    """Draw a "west" frame with columns ``t_id``, ``time``, ``cost``, ``cores``.

    The frame has between 0 and 20 rows. ``t_id`` is ``100..100+height-1`` and
    ``cores`` is drawn from ``1..10``.

    Parameters
    ----------
    draw
        Hypothesis draw function supplied by ``st.composite``.
    with_nulls
        Allow ``None`` in the ``time`` and ``cost`` columns.
    use_floats
        Draw ``time``/``cost`` as arbitrary floats (NaN allowed) stored as
        ``Float32``, instead of small integer ranges stored as ``Int64``.
    """
    height = draw(st.integers(min_value=0, max_value=20))

    def column(values: SearchStrategy[Any]) -> list[Any]:
        # Every column must contain exactly `height` entries.
        return draw(st.lists(values, min_size=height, max_size=height))

    value_dtype: type[pl.DataType]
    time_values: SearchStrategy[Any]
    cost_values: SearchStrategy[Any]
    if not use_floats:
        value_dtype = pl.Int64
        # Same narrow ranges as east_df's dur/rev, so comparisons often tie.
        time_values = st.integers(min_value=100, max_value=105)
        cost_values = st.integers(min_value=9, max_value=13)
    else:
        value_dtype = pl.Float32
        time_values = st.floats(allow_nan=True)
        cost_values = st.floats(allow_nan=True)

    if with_nulls:
        time_values = time_values | st.none()
        cost_values = cost_values | st.none()

    # Draw order (time, cost, cores) is kept stable for reproducible examples.
    time = column(time_values)
    cost = column(cost_values)
    cores = column(st.integers(min_value=1, max_value=10))

    return pl.DataFrame(
        [
            pl.Series("t_id", np.arange(100, 100 + height), dtype=pl.Int64),
            pl.Series("time", time, dtype=value_dtype),
            pl.Series("cost", cost, dtype=value_dtype),
            pl.Series("cores", cores, dtype=pl.Int64),
        ]
    )
401
402
403
@given(
    east=east_df(),
    west=west_df(),
    op1=operators(),
    op2=operators(),
)
def test_ie_join(east: pl.DataFrame, west: pl.DataFrame, op1: str, op2: str) -> None:
    """A conjunction of two inequalities matches a filtered cross join."""
    predicate = _inequality_expression("dur", op1, "time") & _inequality_expression(
        "rev", op2, "cost"
    )

    reference = east.join(west, how="cross").filter(predicate)
    assert_frame_equal(
        east.join_where(west, predicate),
        reference,
        check_row_order=False,
        check_exact=True,
    )
417
418
419
@given(
    east=east_df(with_nulls=True),
    west=west_df(with_nulls=True),
    op1=operators(),
    op2=operators(),
)
def test_ie_join_with_nulls(
    east: pl.DataFrame, west: pl.DataFrame, op1: str, op2: str
) -> None:
    """Like ``test_ie_join``, but the compared columns may contain nulls.

    The join must treat null comparisons the same way the cross-join-then-filter
    reference does.
    """
    predicate = _inequality_expression("dur", op1, "time") & _inequality_expression(
        "rev", op2, "cost"
    )

    reference = east.join(west, how="cross").filter(predicate)
    assert_frame_equal(
        east.join_where(west, predicate),
        reference,
        check_row_order=False,
        check_exact=True,
    )
435
436
437
@given(
    east=east_df(use_floats=True),
    west=west_df(use_floats=True),
    op1=operators(),
    op2=operators(),
)
def test_ie_join_with_floats(
    east: pl.DataFrame, west: pl.DataFrame, op1: str, op2: str
) -> None:
    """Inequality join on Float32 columns (NaN allowed) matches a filtered cross join.

    Unlike ``test_ie_join``, the two predicates are passed to ``join_where``
    as separate arguments rather than combined with ``&``.
    """
    on_duration = _inequality_expression("dur", op1, "time")
    on_revenue = _inequality_expression("rev", op2, "cost")

    reference = east.join(west, how="cross").filter(on_duration & on_revenue)
    assert_frame_equal(
        east.join_where(west, on_duration, on_revenue),
        reference,
        check_row_order=False,
        check_exact=True,
    )
453
454
455
def test_raise_invalid_input_join_where() -> None:
    """Calling ``join_where`` without any predicate raises InvalidOperationError."""
    frame = pl.DataFrame({"id": [1, 2]})
    expected_error = pl.exceptions.InvalidOperationError

    with pytest.raises(expected_error, match="expected join keys/predicates"):
        frame.join_where(frame)
462
463
464
def test_ie_join_use_keys_multiple() -> None:
    """``a >= b`` together with ``a <= b`` on the same columns behaves like ``a == b``."""
    a = pl.LazyFrame({"a": [1, 2, 3], "x": [7, 2, 1]})
    b = pl.LazyFrame({"b": [2, 2, 2], "x": [7, 1, 3]})

    joined = a.join_where(b, pl.col.a >= pl.col.b, pl.col.a <= pl.col.b).collect()

    # Only the left row with a == 2 satisfies both bounds; it pairs with all
    # three right rows. The clashing `x` column is suffixed.
    assert joined.sort("x_right").to_dict(as_series=False) == {
        "a": [2, 2, 2],
        "x": [2, 2, 2],
        "b": [2, 2, 2],
        "x_right": [1, 3, 7],
    }
478
479
480
@given(
    left=series(
        dtype=pl.Int64,
        strategy=st.integers(min_value=0, max_value=10) | st.none(),
        max_size=10,
    ),
    right=series(
        dtype=pl.Int64,
        strategy=st.integers(min_value=-10, max_value=10) | st.none(),
        max_size=10,
    ),
    op=operators(),
)
def test_single_inequality(left: pl.Series, right: pl.Series, op: str) -> None:
    """A single inequality (nullable Int64 data) matches a filtered cross join."""
    predicate = _inequality_expression("x", op, "y")

    left_df = pl.DataFrame({"id": np.arange(len(left)), "x": left})
    right_df = pl.DataFrame({"id": np.arange(len(right)), "y": right})

    reference = left_df.join(right_df, how="cross").filter(predicate)
    assert_frame_equal(
        left_df.join_where(right_df, predicate),
        reference,
        check_row_order=False,
        check_exact=True,
    )
513
514
515
@given(
    offset=st.integers(-6, 5),
    length=st.integers(0, 6),
)
def test_single_inequality_with_slice(offset: int, length: int) -> None:
    """Slicing a single-inequality join yields a valid slice of the full join.

    The output order is unspecified, so the slice is checked by length and by
    content. Every returned row must appear in the full result, and because
    each row is identified by a distinct ``(id, id_right)`` pair, no row may
    appear twice.
    """
    left = pl.DataFrame(
        {
            "id": list(range(8)),
            "x": [0, 1, 1, 2, 3, 5, 5, 7],
        }
    )
    right = pl.DataFrame(
        {
            "id": list(range(6)),
            "y": [-1, 2, 4, 4, 6, 9],
        }
    )

    expr = pl.col("x") > pl.col("y")
    actual = left.join_where(right, expr).slice(offset, length)

    expected_full = left.join(right, how="cross").filter(expr)

    assert len(actual) == len(expected_full.slice(offset, length))

    expected_rows = set(expected_full.iter_rows())
    actual_rows = actual.rows()
    # Membership alone would accept a slice that repeats a valid row. The
    # expected rows are all distinct, so a correct slice has no duplicates.
    assert len(set(actual_rows)) == len(actual_rows), f"duplicate rows: {actual_rows}"
    for row in actual_rows:
        assert row in expected_rows, f"{row} not in expected rows"
543
544
545
def test_ie_join_projection_pd_19005() -> None:
    """Regression test (issue 19005): group_by/agg over an empty two-predicate join.

    With this data no pair satisfies both predicates, so the aggregation must
    return an empty frame that still has the correct index-typed schema.
    """
    lf = pl.LazyFrame({"a": [1, 2], "b": [3, 4]}).with_row_index()
    index_type = pl.get_index_type()

    out = (
        lf.join_where(
            lf,
            pl.col.index < pl.col.index_right,
            pl.col.index.cast(pl.Int64) + pl.col.a > pl.col.a_right,
        )
        .group_by(pl.col.index)
        .agg(pl.col.index_right)
        .collect()
    )

    expected_schema = pl.Schema(
        [("index", index_type), ("index_right", pl.List(index_type))]
    )
    assert out.schema == expected_schema
    assert out.shape == (0, 2)
562
563
564
def test_single_sided_predicate() -> None:
    """A predicate referencing only left-hand columns is accepted by join_where."""
    left = pl.LazyFrame({"a": [1, -1, 2]}).with_row_index()
    right = pl.LazyFrame({"b": [1, 2]})

    joined = left.join_where(right, pl.col.index >= pl.col.a).collect()
    result = joined.sort("index", "a", "b")

    # Left rows with index >= a are (1, -1) and (2, 2); each pairs with
    # every right row.
    expected = pl.DataFrame(
        {
            "index": pl.Series([1, 1, 2, 2], dtype=pl.get_index_type()),
            "a": [-1, -1, 2, 2],
            "b": [1, 2, 1, 2],
        }
    )
    assert_frame_equal(result, expected)
581
582
583
def test_join_on_strings() -> None:
    """An inequality on string columns is planned as a nested-loop join."""
    df = pl.LazyFrame({"a": ["a", "b", "c"], "b": ["b", "b", "b"]})

    q = df.join_where(df, pl.col("a").ge(pl.col("a_right")))
    assert "NESTED LOOP JOIN" in q.explain()

    # Note: Output is flaky without sort when POLARS_MAX_THREADS=1
    result = q.collect().sort(pl.all())
    assert result.to_dict(as_series=False) == {
        "a": ["a", "b", "b", "c", "c", "c"],
        "b": ["b", "b", "b", "b", "b", "b"],
        "a_right": ["a", "a", "b", "a", "b", "c"],
        "b_right": ["b", "b", "b", "b", "b", "b"],
    }
601
602
603
def test_join_partial_column_name_overlap_19119() -> None:
    """Only right-hand columns whose names clash get the ``_right`` suffix (issue 19119)."""
    left = pl.LazyFrame({"a": [1], "b": [2]})
    right = pl.LazyFrame({"a": [2], "d": [0]})

    result = left.join_where(right, pl.col("a") > pl.col("d")).collect()

    # `a` exists on both sides and is suffixed; `d` exists only on the right.
    expected = {"a": [1], "b": [2], "a_right": [2], "d": [0]}
    assert result.to_dict(as_series=False) == expected
615
616
617
def test_join_predicate_pushdown_19580() -> None:
    """Regression test (issue 19580): three join_where predicates vs. a cross join.

    The predicates reference suffixed right-hand columns (``c_right``,
    ``a_right``) as well as a column that exists only on the right (``d``).
    The result must equal a cross join filtered by the same three conditions.
    """
    left = pl.LazyFrame(
        {
            "a": [1, 2, 3, 1],
            "b": [1, 2, 3, 4],
            "c": [2, 3, 4, 5],
        }
    )

    right = pl.LazyFrame({"a": [1, 3], "c": [2, 4], "d": [6, 3]})

    q = left.join_where(
        right,
        pl.col("b") < pl.col("c_right"),
        pl.col("a") < pl.col("a_right"),
        pl.col("a") < pl.col("d"),
    )

    expected = (
        left.join(right, how="cross")
        .collect()
        .filter(
            (pl.col("a") < pl.col("d"))
            & (pl.col("b") < pl.col("c_right"))
            & (pl.col("a") < pl.col("a_right"))
        )
    )

    # Pass (actual, expected) in that order so failure messages label the
    # frames correctly.
    assert_frame_equal(q.collect(), expected, check_row_order=False)
646
647
648
def test_join_where_literal_20061() -> None:
    """A predicate comparing a column with a literal is accepted (issue 20061)."""
    df_left = pl.DataFrame(
        {"id": [1, 2, 3], "value_left": [10, 20, 30], "flag": [1, 0, 1]}
    )
    df_right = pl.DataFrame(
        {"id": [1, 2, 3], "value_right": [5, 5, 25], "flag": [1, 0, 1]}
    )

    joined = df_left.join_where(
        df_right,
        pl.col("value_left") > pl.col("value_right"),
        pl.col("flag_right") == pl.lit(1, dtype=pl.Int8),
    )

    # The literal predicate keeps only right rows with flag == 1 (ids 1 and 3).
    expected = {
        "id": [1, 2, 3, 3],
        "value_left": [10, 20, 30, 30],
        "flag": [1, 0, 1, 1],
        "id_right": [1, 1, 1, 3],
        "value_right": [5, 5, 5, 25],
        "flag_right": [1, 1, 1, 1],
    }
    assert joined.sort(pl.all()).to_dict(as_series=False) == expected
673
674
675
def test_boolean_predicate_join_where() -> None:
    """A non-comparison boolean predicate is planned as a nested-loop join."""
    urls = pl.LazyFrame({"url": "abcd.com/page"})
    categories = pl.LazyFrame({"base_url": "abcd.com", "category": "landing page"})

    is_prefixed = pl.col("url").str.starts_with(pl.col("base_url"))
    plan = urls.join_where(categories, is_prefixed).explain()
    assert "NESTED LOOP JOIN" in plan
684
685