Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_cross_join.py
8424 views
1
from datetime import datetime
2
from zoneinfo import ZoneInfo
3
4
import pytest
5
6
import polars as pl
7
from polars._typing import MaintainOrderJoin
8
from polars.testing import assert_frame_equal
9
10
11
def test_cross_join_predicate_pushdown_block_16956() -> None:
    """Check a filtered cross self-join collected with predicate pushdown on.

    The frame holds three (start, end) pairs given as millisecond timestamps
    and cast to ``Europe/Amsterdam`` datetimes. It is cross-joined with itself
    and only rows whose right-hand end lies between the left-hand start and
    that start offset by ``"132h"`` are kept.

    The number in the test name presumably refers to the originating issue.
    """
    amsterdam = ZoneInfo(key="Europe/Amsterdam")

    def at(day: int, hour: int) -> datetime:
        # All expected values fall in June 2024, on the hour, Amsterdam time.
        return datetime(2024, 6, day, hour, 0, tzinfo=amsterdam)

    starts_ms = [1718085600000, 1718172000000, 1718776800000]
    ends_ms = [1718114400000, 1718200800000, 1718805600000]
    lf = pl.LazyFrame(
        [starts_ms, ends_ms],
        schema=["start_datetime", "end_datetime"],
    ).cast(pl.Datetime("ms", "Europe/Amsterdam"))

    in_window = pl.col.end_datetime_right.is_between(
        pl.col.start_datetime, pl.col.start_datetime.dt.offset_by("132h")
    )
    query = (
        lf.join(lf, how="cross")
        .filter(in_window)
        .select("start_datetime", "end_datetime_right")
    )
    result = query.collect(
        optimizations=pl.QueryOptFlags(predicate_pushdown=True)
    ).to_dict(as_series=False)

    expected = {
        "start_datetime": [at(day, 8) for day in (11, 11, 12, 19)],
        "end_datetime_right": [at(day, 16) for day in (11, 12, 12, 19)],
    }
    assert result == expected
44
45
46
def test_cross_join_raise_on_keys() -> None:
    """Passing ``left_on``/``right_on`` to a cross join must raise ``ValueError``."""
    frame = pl.DataFrame({"a": [0, 1], "b": ["x", "y"]})

    with pytest.raises(ValueError):
        frame.join(frame, left_on="a", right_on="b", how="cross")
51
52
53
def test_nested_loop_join() -> None:
    """Check ``join_where`` with an inequality predicate.

    Asserts that the explained plan mentions ``NESTED LOOP JOIN`` and that the
    collected result holds every (left, right) row pair with ``a != c``,
    compared without regard to row order.
    """
    lhs = pl.LazyFrame({"a": [1, 2, 1, 3], "b": [1, 2, 3, 4]})
    rhs = pl.LazyFrame({"c": [4, 1, 2], "d": [1, 2, 3]})

    joined = lhs.join_where(rhs, pl.col("a") != pl.col("c"))
    assert "NESTED LOOP JOIN" in joined.explain()

    expected = pl.DataFrame(
        {
            "a": [1, 1, 2, 2, 1, 1, 3, 3, 3],
            "b": [1, 1, 2, 2, 3, 3, 4, 4, 4],
            "c": [4, 2, 4, 1, 4, 2, 4, 1, 2],
            "d": [1, 3, 1, 2, 1, 3, 1, 2, 3],
        }
    )
    result = joined.collect()
    assert_frame_equal(result, expected, check_row_order=False, check_exact=True)
81
82
83
def test_cross_join_chunking_panic_22793() -> None:
    """Cross join a frame with an empty frame, then filter, without panicking.

    The row count is derived from the thread pool size. Column ``a`` is
    assembled by concatenating one-element Series (presumably so that it ends
    up with many chunks -- the test name points at chunking). The result is
    expected to equal the empty frame produced from the input's schema.
    """
    n_rows = int(pl.thread_pool_size() ** 0.5) * 2

    col_a = pl.concat([pl.Series("a", [0]) for _ in range(n_rows)])
    col_b = pl.Series("b", [0] * n_rows)
    df = pl.DataFrame([col_a, col_b])

    empty = pl.DataFrame().lazy()
    result = (
        df.lazy()
        .join(empty, how="cross")
        .filter(pl.col("a") == pl.col("a"))
        .collect()
    )
    assert_frame_equal(result, df.schema.to_frame())
95
96
97
@pytest.mark.parametrize(
    "maintain_order", ["left", "right", "left_right", "right_left"]
)
def test_cross_join_maintain_order_24663(maintain_order: MaintainOrderJoin) -> None:
    """Check the row order of a cross join for each ``maintain_order`` value.

    For values starting with ``"left"`` the expected ``x`` column changes
    slowly (each value repeated five times) while ``y`` cycles; otherwise the
    roles are swapped. The same expectation is checked again on the lazy path
    with a filter applied directly after the join.
    """
    left = pl.DataFrame({"x": [0, 1, 2, 3, 4]})
    right = pl.DataFrame({"y": [0, 1, 2, 3, 4]})

    # 0,0,0,0,0,1,1,... versus 0,1,2,3,4,0,1,...
    slow = [outer for outer in range(5) for _ in range(5)]
    fast = [inner for _ in range(5) for inner in range(5)]

    x_vals, y_vals = (
        (slow, fast) if maintain_order.startswith("left") else (fast, slow)
    )
    expected = pl.DataFrame({"x": x_vals, "y": y_vals})

    eager = left.join(right, how="cross", maintain_order=maintain_order)
    assert_frame_equal(eager, expected)

    # Test with fused filter as well.
    even_sum = (pl.col.x + pl.col.y) % 2 == 0
    lazy_filtered = (
        left.lazy()
        .join(right.lazy(), how="cross", maintain_order=maintain_order)
        .filter(even_sum)
    )
    assert_frame_equal(lazy_filtered, expected.lazy().filter(even_sum))
121
122