Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/benchmark/test_join_where.py
6939 views
1
"""Benchmark tests for join_where with inequality conditions."""
2
3
from __future__ import annotations
4
5
import numpy as np
6
import pytest
7
8
import polars as pl
9
from polars.exceptions import ColumnNotFoundError
10
from polars.testing import assert_frame_equal
11
12
# Apply the `benchmark` marker to every test collected from this module.
pytestmark = pytest.mark.benchmark()
13
14
15
def test_strict_inequalities(east_west: tuple[pl.DataFrame, pl.DataFrame]) -> None:
    """Benchmark `join_where` with two strict (`<` and `>`) predicates."""
    east_df, west_df = east_west

    # Both conditions are handed over together as a single list of predicates.
    predicates = [
        pl.col("dur") < pl.col("time"),
        pl.col("rev") > pl.col("cost"),
    ]
    joined = east_df.lazy().join_where(west_df.lazy(), predicates)
    result = joined.collect()

    # The join is expected to produce at least one matching pair of rows.
    assert len(result) > 0
27
28
29
def test_non_strict_inequalities(east_west: tuple[pl.DataFrame, pl.DataFrame]) -> None:
    """Benchmark `join_where` with two non-strict (`<=` and `>=`) predicates."""
    east_df, west_df = east_west

    # Both conditions are handed over together as a single list of predicates.
    predicates = [
        pl.col("dur") <= pl.col("time"),
        pl.col("rev") >= pl.col("cost"),
    ]
    joined = east_df.lazy().join_where(west_df.lazy(), predicates)
    result = joined.collect()

    # The join is expected to produce at least one matching pair of rows.
    assert len(result) > 0
41
42
43
def test_single_inequality(east_west: tuple[pl.DataFrame, pl.DataFrame]) -> None:
    """Benchmark `join_where` with a single strict inequality predicate."""
    east_df, west_df = east_west

    # Scaling the left-hand `dur` column up by 30 reduces the number of rows
    # that satisfy `scaled_dur < time`, keeping the join output smaller.
    scaled_east = east_df.lazy().with_columns(
        (pl.col("dur") * 30).alias("scaled_dur")
    )
    predicate = pl.col("scaled_dur") < pl.col("time")
    result = scaled_east.join_where(west_df.lazy(), predicate).collect()

    # The join is expected to produce at least one matching pair of rows.
    assert len(result) > 0
57
58
59
@pytest.fixture(scope="module")
def east_west() -> tuple[pl.DataFrame, pl.DataFrame]:
    """Create the seeded `east`/`west` frames used by the benchmark tests.

    Returns
    -------
    tuple[pl.DataFrame, pl.DataFrame]
        ``(east, west)``: `east` has 50,000 rows with columns `id`, `dur`,
        `rev` and `cores`; `west` has 5,000 rows with columns `t_id`, `time`,
        `cost` and `cores`.
    """
    n_east, n_west = 50_000, 5_000
    generator = np.random.default_rng(42)

    # NOTE: `generator` is stateful, so the order of the draws below fixes the
    # generated data; keep the sequence unchanged when editing.

    # East table: revenue is a linear function of duration.
    dur = generator.integers(1_000, 50_000, n_east)
    rev = (dur * 0.123).astype(np.int32)

    # West table: cost is the same linear function of time plus some noise, so
    # that there are some rows where the cost for the same or greater time is
    # less than the revenue in the east table.
    west_time = generator.integers(1_000, 50_000, n_west)
    noise = generator.normal(0.0, 1.0, n_west)
    cost = (west_time * 0.123 + noise).astype(np.int32)

    east_cores = generator.integers(1, 10, n_east)
    west_cores = generator.integers(1, 10, n_west)

    east_columns = {
        "id": np.arange(0, n_east),
        "dur": dur,
        "rev": rev,
        "cores": east_cores,
    }
    west_columns = {
        "t_id": np.arange(0, n_west),
        "time": west_time,
        "cost": cost,
        "cores": west_cores,
    }

    return pl.DataFrame(east_columns), pl.DataFrame(west_columns)
92
93
94
def test_join_where_invalid_column() -> None:
    """Check that `join_where` raises `ColumnNotFoundError` for unknown columns.

    Two cases are covered: a predicate that references the missing column
    directly, and one where the missing column is nested inside a larger
    expression.
    """
    # Direct reference: "y" exists in neither side of the self-join.
    df = pl.DataFrame({"x": 1})
    with pytest.raises(ColumnNotFoundError, match="y"):
        df.join_where(df, pl.col("x") < pl.col("y"))

    # Nested column: "d" exists in neither `df1` nor `df2`.
    df1 = pl.DataFrame({"a": [1, 2, 3], "b": [True, False, True]})
    df2 = pl.DataFrame(
        {
            "a": [2, 3, 4],
            "c": ["a", "b", "c"],
        }
    )
    # Build the predicate up front so that only the `join_where` call itself
    # sits inside the `pytest.raises` block.
    nested_predicate = (
        (pl.col("a") - pl.col("b")) > (pl.col("c") == "a").cast(pl.Int32)
    ) > (pl.col("a") - pl.col("d"))
    with pytest.raises(ColumnNotFoundError, match="d"):
        # The result is deliberately not bound: the call is expected to raise.
        df1.join_where(df2, nested_predicate)
113
114
115
def test_join_where_not_elementwise_24134() -> None:
    """Check `join_where` with a predicate that compares against `pl.len()`.

    The expected output pairs ``a == 16`` with every row of ``b``.
    """
    left = pl.LazyFrame({"a": [0, 1, 2, 16]})
    right = pl.LazyFrame({"b": [0, 1, 2, 16]})

    # NOTE(review): presumably `pl.len()` evaluates to 16 here (4 x 4 candidate
    # row pairs), which is why only `a == 16` matches - confirm against 24134.
    predicate = pl.col("a") == pl.len()
    out = left.join_where(right, predicate).collect()

    expected = pl.DataFrame({"a": [16] * 4, "b": [0, 1, 2, 16]})
    # Row order of the join output is not part of the contract being tested.
    assert_frame_equal(out, expected, check_row_order=False)
127
128