Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_cross_join.py
6939 views
1
from datetime import datetime
2
from zoneinfo import ZoneInfo
3
4
import pytest
5
6
import polars as pl
7
from polars.testing import assert_frame_equal
8
9
10
def test_cross_join_predicate_pushdown_block_16956() -> None:
    """Check a range filter on a self cross join with predicate pushdown on.

    A frame of start/end timestamps is cross joined with itself and filtered
    to the pairs whose right-hand end falls between the left-hand start and
    that start offset by 132 hours.  The numeric suffix in the test name is
    presumably the upstream issue number -- confirm against the tracker.
    """
    amsterdam = ZoneInfo(key="Europe/Amsterdam")

    # Epoch values, interpreted as milliseconds by the cast below.
    start_ms = [1718085600000, 1718172000000, 1718776800000]
    end_ms = [1718114400000, 1718200800000, 1718805600000]
    lf = pl.LazyFrame(
        [start_ms, end_ms],
        schema=["start_datetime", "end_datetime"],
    ).cast(pl.Datetime("ms", "Europe/Amsterdam"))

    window_start = pl.col.start_datetime
    window_end = pl.col.start_datetime.dt.offset_by("132h")
    query = (
        lf.join(lf, how="cross")
        .filter(pl.col.end_datetime_right.is_between(window_start, window_end))
        .select("start_datetime", "end_datetime_right")
    )
    result = query.collect(
        optimizations=pl.QueryOptFlags(predicate_pushdown=True)
    ).to_dict(as_series=False)

    expected = {
        "start_datetime": [
            datetime(2024, 6, 11, 8, 0, tzinfo=amsterdam),
            datetime(2024, 6, 11, 8, 0, tzinfo=amsterdam),
            datetime(2024, 6, 12, 8, 0, tzinfo=amsterdam),
            datetime(2024, 6, 19, 8, 0, tzinfo=amsterdam),
        ],
        "end_datetime_right": [
            datetime(2024, 6, 11, 16, 0, tzinfo=amsterdam),
            datetime(2024, 6, 12, 16, 0, tzinfo=amsterdam),
            datetime(2024, 6, 12, 16, 0, tzinfo=amsterdam),
            datetime(2024, 6, 19, 16, 0, tzinfo=amsterdam),
        ],
    }
    assert result == expected
43
44
45
def test_cross_join_raise_on_keys() -> None:
    """Expect a ValueError when a cross join is given left_on/right_on keys."""
    frame = pl.DataFrame({"a": [0, 1], "b": ["x", "y"]})
    join_kwargs = {"how": "cross", "left_on": "a", "right_on": "b"}

    with pytest.raises(ValueError):
        frame.join(frame, **join_kwargs)
50
51
52
def test_nested_loop_join() -> None:
    """Check plan and result of join_where with an inequality predicate.

    The explained plan must mention "NESTED LOOP JOIN", and the collected
    result must contain every (left, right) row pair where ``a != c``,
    compared without regard to row order.
    """
    lhs = pl.LazyFrame({"a": [1, 2, 1, 3], "b": [1, 2, 3, 4]})
    rhs = pl.LazyFrame({"c": [4, 1, 2], "d": [1, 2, 3]})
    expected = pl.DataFrame(
        {
            "a": [1, 1, 2, 2, 1, 1, 3, 3, 3],
            "b": [1, 1, 2, 2, 3, 3, 4, 4, 4],
            "c": [4, 2, 4, 1, 4, 2, 4, 1, 2],
            "d": [1, 3, 1, 2, 1, 3, 1, 2, 3],
        }
    )

    actual = lhs.join_where(rhs, pl.col("a") != pl.col("c"))

    # Inspect the plan before executing the query.
    assert "NESTED LOOP JOIN" in actual.explain()

    collected = actual.collect()
    assert_frame_equal(collected, expected, check_row_order=False, check_exact=True)
80
81
82
def test_cross_join_chunking_panic_22793() -> None:
    """Cross join a frame with an empty frame and compare to its schema frame.

    Column "a" is assembled by concatenating N single-element series, where N
    is derived from the thread pool size; presumably this leaves the column
    split into several chunks, which is what the test name refers to --
    confirm against the upstream issue.
    """
    n_rows = int(pl.thread_pool_size() ** 0.5) * 2

    col_a = pl.concat([pl.Series("a", [0]) for _ in range(n_rows)])
    col_b = pl.Series("b", [0] * n_rows)
    df = pl.DataFrame([col_a, col_b])

    empty = pl.DataFrame().lazy()
    always_true = pl.col("a") == pl.col("a")
    result = df.lazy().join(empty, how="cross").filter(always_true).collect()

    # Expected value is the frame polars builds from df's schema alone.
    assert_frame_equal(result, df.schema.to_frame())
94
95