Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/test_skip_batch_predicate.py
6939 views
1
from __future__ import annotations
2
3
import contextlib
4
import datetime
5
from typing import TYPE_CHECKING, Any, TypedDict
6
7
from hypothesis import Phase, given, settings
8
9
import polars as pl
10
from polars.meta import get_index_type
11
from polars.testing import assert_frame_equal, assert_series_equal
12
from polars.testing.parametric.strategies import series
13
14
if TYPE_CHECKING:
15
from collections.abc import Sequence
16
17
from polars._typing import PythonLiteral
18
19
20
class Case(TypedDict):
    """
    One row of batch statistics together with the expected skip verdict.

    Every key is required; the statistic values themselves may be ``None``.
    """

    # Minimum and maximum statistic of the column under test.
    min: Any | None
    max: Any | None
    # Null-count statistic of the column and row-count statistic of the batch.
    null_count: int | None
    len: int | None
    # Expected verdict: ``True`` when the predicate should report the batch
    # as skippable.
    can_skip: bool
28
29
30
def assert_skp_series(
    name: str,
    dtype: pl.DataType,
    expr: pl.Expr,
    cases: Sequence[Case],
) -> None:
    """
    Check the skip-batch predicate of `expr` against hand-written statistics.

    The predicate derived from `expr` is evaluated on a frame holding one row
    per case, with the columns ``{name}_min``, ``{name}_max``, ``{name}_nc``
    and ``len``. The result is compared with each case's ``can_skip`` value.

    Parameters
    ----------
    name
        Name of the column that `expr` refers to.
    dtype
        Data type of that column; also used for the min/max statistic columns.
    expr
        Filter expression whose skip-batch predicate is under test.
    cases
        Statistics rows and the expected verdict for each of them.

    Raises
    ------
    AssertionError
        If no skip-batch predicate could be derived from `expr`, or if the
        evaluated predicate disagrees with the expected verdicts. In the
        latter case the predicate is printed first to ease debugging.
    """
    sbp = expr._skip_batch_predicate({name: dtype})
    # Without a predicate there is nothing to evaluate; fail loudly instead of
    # handing `None` to `select` below.
    assert sbp is not None, f"no skip batch predicate was generated for {expr}"

    index_type = get_index_type()
    df = pl.DataFrame(
        [
            pl.Series(f"{name}_min", [case["min"] for case in cases], dtype),
            pl.Series(f"{name}_max", [case["max"] for case in cases], dtype),
            pl.Series(f"{name}_nc", [case["null_count"] for case in cases], index_type),
            pl.Series("len", [case["len"] for case in cases], index_type),
        ]
    )
    expected = pl.Series(
        "can_skip", [case["can_skip"] for case in cases], pl.Boolean
    )

    # A null verdict is counted as "do not skip".
    actual = df.select(can_skip=sbp).to_series().fill_null(False)

    try:
        assert_series_equal(actual, expected)
    except AssertionError:
        print(sbp)
        raise
56
57
58
def test_true_false_predicate() -> None:
    """Literal predicates: only a literal ``True`` yields a ``False`` skip verdict."""
    literals = {"true": True, "false": False, "null": None}
    predicates = {
        label: pl.lit(value)._skip_batch_predicate({})
        for label, value in literals.items()
    }

    # The predicates reference no column; a single-row frame with only the
    # row-count statistic is used to evaluate them.
    stats = pl.DataFrame({"len": [1]})
    result = stats.select(**predicates)

    expected = pl.DataFrame(
        {
            "true": [False],
            "false": [True],
            "null": [True],
        }
    )
    assert_frame_equal(result, expected)
81
82
83
def test_equality() -> None:
    """Check the skip verdicts for ``==`` and ``!=`` on an Int64 column."""
    column = "a"
    dtype = pl.Int64()

    # `a == 5`: skippable when 5 lies outside [min, max] or every row is null.
    eq_cases: list[Case] = [
        {"min": 1, "max": 2, "null_count": 0, "len": 42, "can_skip": True},
        {"min": 6, "max": 7, "null_count": 0, "len": 42, "can_skip": True},
        {"min": 1, "max": 7, "null_count": 0, "len": 42, "can_skip": False},
        {"min": None, "max": None, "null_count": 42, "len": 42, "can_skip": True},
    ]
    assert_skp_series(column, dtype, pl.col(column) == 5, eq_cases)

    # `a != 0`: a batch of zeros mixed with nulls is expected to be kept.
    ne_cases: list[Case] = [
        {"min": 0, "max": 0, "null_count": 6, "len": 7, "can_skip": False},
    ]
    assert_skp_series(column, dtype, pl.col(column) != 0, ne_cases)
104
105
106
def test_datetimes() -> None:
    """Check the skip verdicts for ``a == <datetime>`` on a UTC datetime column."""
    d = datetime.datetime(2023, 4, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
    td = datetime.timedelta

    assert_skp_series(
        "a",
        pl.Datetime(time_zone=datetime.timezone.utc),
        pl.col("a") == d,
        [
            # The whole batch lies before `d`.
            {
                "min": d - td(days=2),
                "max": d - td(days=1),
                "null_count": 0,
                "len": 42,
                "can_skip": True,
            },
            # The whole batch lies after `d`. The maximum is kept above the
            # minimum so that only the "min above the literal" condition can
            # justify skipping.
            {
                "min": d + td(days=1),
                "max": d + td(days=2),
                "null_count": 0,
                "len": 42,
                "can_skip": True,
            },
            # Null count equals the row count.
            {"min": d, "max": d, "null_count": 42, "len": 42, "can_skip": True},
            # min == max == d with no nulls: the batch must be kept.
            {"min": d, "max": d, "null_count": 0, "len": 42, "can_skip": False},
            # `d` lies inside [min, max]: the batch must be kept.
            {
                "min": d - td(days=2),
                "max": d + td(days=2),
                "null_count": 0,
                "len": 42,
                "can_skip": False,
            },
            # Only the minimum is available, and it is already above `d`.
            {
                "min": d + td(days=1),
                "max": None,
                "null_count": None,
                "len": None,
                "can_skip": True,
            },
        ],
    )
147
148
149
def _batch_statistics_frame(s: pl.Series, name: str) -> pl.DataFrame:
    """
    Build the single-row statistics frame describing `s`.

    The frame has the columns ``{name}_min``, ``{name}_max``, ``{name}_nc``
    and ``len``. If computing the minimum or maximum raises, that statistic
    is stored as null.
    """
    dtype = s.dtype
    index_type = get_index_type()

    mins: list[PythonLiteral | None] = [None]
    with contextlib.suppress(Exception):
        mins = [s.min()]

    maxs: list[PythonLiteral | None] = [None]
    with contextlib.suppress(Exception):
        maxs = [s.max()]

    return pl.DataFrame(
        [
            pl.Series(f"{name}_min", mins, dtype),
            pl.Series(f"{name}_max", maxs, dtype),
            pl.Series(f"{name}_nc", [s.null_count()], index_type),
            pl.Series("len", [s.len()], index_type),
        ]
    )


@given(
    s=series(
        name="x",
        min_size=1,
    ),
)
@settings(
    report_multiple_bugs=False,
    phases=(Phase.explicit, Phase.reuse, Phase.generate, Phase.target, Phase.explain),
)
def test_skip_batch_predicate_parametric(s: pl.Series) -> None:
    """
    Verify that a positive skip verdict never drops matching rows.

    For a generated series, several filter expressions are built from its
    first (and, when present, second) value. Whenever the skip-batch predicate
    of an expression says the batch can be skipped, filtering the series with
    that expression must yield no rows.
    """
    name = "x"
    dtype = s.dtype

    value_a = s.slice(0, 1)
    lit_a = pl.lit(value_a[0], dtype)

    exprs = [
        pl.col.x == lit_a,
        pl.col.x != lit_a,
        pl.col.x.eq_missing(lit_a),
        pl.col.x.ne_missing(lit_a),
        pl.col.x.is_null(),
        pl.col.x.is_not_null(),
    ]

    # Best-effort by design: if the probe comparison (or anything after it)
    # raises, the remaining expressions are simply not added.
    with contextlib.suppress(Exception):
        _ = s > value_a
        exprs += [
            pl.col.x > lit_a,
            pl.col.x >= lit_a,
            pl.col.x < lit_a,
            pl.col.x <= lit_a,
            pl.col.x.is_in(pl.Series([None, value_a[0]], dtype=dtype)),
        ]

        if s.len() > 1:
            value_b = s.slice(1, 1)
            lit_b = pl.lit(value_b[0], dtype)

            exprs += [
                pl.col.x.is_between(lit_a, lit_b),
                pl.col.x.is_in(pl.Series([value_a[0], value_b[0]], dtype=dtype)),
            ]

    # The statistics depend only on `s`, so they are built once, and only when
    # the first expression with a predicate needs them.
    df: pl.DataFrame | None = None

    for expr in exprs:
        sbp = expr._skip_batch_predicate({name: dtype})

        if sbp is None:
            continue

        if df is None:
            df = _batch_statistics_frame(s, name)

        # A null verdict is counted as "do not skip".
        can_skip = df.select(can_skip=sbp).fill_null(False).to_series()[0]
        if not can_skip:
            continue

        try:
            assert s.to_frame().filter(expr).height == 0
        except Exception:
            print(expr)
            print(sbp)
            print(df)
            print(s.to_frame().filter(expr))

            raise
234
235