Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/test_pyarrow_dataset.py
6939 views
1
from __future__ import annotations
2
3
from datetime import date, datetime, time
4
from typing import TYPE_CHECKING, Callable
5
6
import pyarrow.dataset as ds
7
8
import polars as pl
9
from polars.testing import assert_frame_equal
10
11
if TYPE_CHECKING:
12
from pathlib import Path
13
14
import pytest
15
16
17
def helper_dataset_test(
    file_path: Path,
    query: Callable[[pl.LazyFrame], pl.LazyFrame],
    batch_size: int | None = None,
    n_expected: int | None = None,
    check_predicate_pushdown: bool = False,
) -> None:
    """Run ``query`` against both a native IPC scan and a pyarrow dataset scan
    of the same file and assert that the two results are equal.

    Optionally also asserts the expected row count and, when
    ``check_predicate_pushdown`` is set, that the predicate was pushed into
    the scan (i.e. no ``FILTER`` node remains in the optimized plan).
    """
    pyarrow_dataset = ds.dataset(file_path, format="ipc")

    reference = pl.scan_ipc(file_path).pipe(query)
    expected = reference.collect()

    dataset_scan = pl.scan_pyarrow_dataset(pyarrow_dataset, batch_size=batch_size)
    result = dataset_scan.pipe(query).collect()
    assert_frame_equal(result, expected)

    if n_expected is not None:
        assert len(result) == n_expected

    if check_predicate_pushdown:
        assert "FILTER" not in reference.explain()
35
36
37
# @pytest.mark.write_disk()
def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None:
    """Exercise predicate/projection pushdown through a pyarrow dataset scan.

    Each case is (query, batch_size, n_expected, check_predicate_pushdown)
    and is checked against an equivalent native IPC scan.
    """
    file_path = tmp_path / "small.ipc"
    df.write_ipc(file_path)

    cases = [
        # boolean column filters
        (
            lambda lf: lf.filter("bools").select("bools", "floats", "date"),
            None,
            1,
            True,
        ),
        (
            lambda lf: lf.filter(~pl.col("bools")).select("bools", "floats", "date"),
            None,
            2,
            True,
        ),
        # null checks
        (
            lambda lf: lf.filter(pl.col("int_nulls").is_null()).select(
                "bools", "floats", "date"
            ),
            None,
            1,
            True,
        ),
        (
            lambda lf: lf.filter(pl.col("int_nulls").is_not_null()).select(
                "bools", "floats", "date"
            ),
            None,
            2,
            True,
        ),
        (
            lambda lf: lf.filter(
                pl.col("int_nulls").is_not_null() == pl.col("bools")
            ).select("bools", "floats", "date"),
            None,
            0,
            True,
        ),
        # this equality on a column with nulls fails as pyarrow has different
        # handling kleene logic. We leave it for now and document it in the function.
        (
            lambda lf: lf.filter(pl.col("int") == 10).select(
                "bools", "floats", "int_nulls"
            ),
            None,
            0,
            True,
        ),
        (
            lambda lf: lf.filter(pl.col("int") != 10).select(
                "bools", "floats", "int_nulls"
            ),
            None,
            3,
            True,
        ),
        # is_between with every `closed` variant
        (
            lambda lf: lf.filter(
                pl.col("int").is_between(1, 3, closed="both")
            ).select("bools", "floats", "date"),
            None,
            3,
            True,
        ),
        (
            lambda lf: lf.filter(
                pl.col("int").is_between(1, 3, closed="left")
            ).select("bools", "floats", "date"),
            None,
            2,
            True,
        ),
        (
            lambda lf: lf.filter(
                pl.col("int").is_between(1, 3, closed="right")
            ).select("bools", "floats", "date"),
            None,
            2,
            True,
        ),
        (
            lambda lf: lf.filter(
                pl.col("int").is_between(1, 3, closed="none")
            ).select("bools", "floats", "date"),
            None,
            1,
            True,
        ),
        # this predicate is not supported by pyarrow
        # check if we still do it on our side
        (
            lambda lf: lf.filter(pl.col("floats").sum().over("date") == 10).select(
                "bools", "floats", "date"
            ),
            None,
            0,
            False,
        ),
        # temporal types
        (
            lambda lf: lf.filter(pl.col("date") < date(1972, 1, 1)).select(
                "bools", "floats", "date"
            ),
            None,
            1,
            True,
        ),
        (
            lambda lf: lf.filter(
                pl.col("datetime") > datetime(1970, 1, 1, second=13)
            ).select("bools", "floats", "date"),
            None,
            1,
            True,
        ),
        # not yet supported in pyarrow
        (
            lambda lf: lf.filter(pl.col("time") >= time(microsecond=100)).select(
                "bools", "time", "date"
            ),
            None,
            3,
            True,
        ),
        # pushdown is_in
        (
            lambda lf: lf.filter(pl.col("int").is_in([1, 3, 20])).select(
                "bools", "floats", "date"
            ),
            None,
            2,
            True,
        ),
        (
            lambda lf: lf.filter(
                pl.col("date").is_in([date(1973, 8, 17), date(1973, 5, 19)])
            ).select("bools", "floats", "date"),
            None,
            2,
            True,
        ),
        (
            lambda lf: lf.filter(
                pl.col("datetime").is_in(
                    [
                        datetime(1970, 1, 1, 0, 0, 12, 341234),
                        datetime(1970, 1, 1, 0, 0, 13, 241324),
                    ]
                )
            ).select("bools", "floats", "date"),
            None,
            2,
            True,
        ),
        (
            lambda lf: lf.filter(pl.col("int").is_in(list(range(120)))).select(
                "bools", "floats", "date"
            ),
            None,
            3,
            True,
        ),
        (
            lambda lf: lf.filter(pl.col("cat").is_in([])).select(
                "bools", "floats", "date"
            ),
            None,
            0,
            False,
        ),
        # projection pushdown with a small batch size
        (
            lambda lf: lf.select(pl.exclude("enum")),
            2,
            3,
            False,
        ),
        # direct filter
        (
            lambda lf: lf.filter(pl.Series([True, False, True])).select(
                "bools", "floats", "date"
            ),
            None,
            2,
            False,
        ),
        (
            lambda lf: lf.filter(
                pl.col("bools") & pl.col("int").is_in([1, 2])
            ).select("bools", "floats"),
            None,
            1,
            True,
        ),
    ]

    for query, batch_size, n_expected, check_pushdown in cases:
        helper_dataset_test(
            file_path,
            query,
            batch_size=batch_size,
            n_expected=n_expected,
            check_predicate_pushdown=check_pushdown,
        )
208
209
210
def test_pyarrow_dataset_comm_subplan_elim(tmp_path: Path) -> None:
    """Join two pyarrow-dataset scans (exercises common-subplan elimination)."""
    left_df = pl.DataFrame({"a": [1, 2, 3]})
    right_df = pl.DataFrame({"a": [1, 2]})

    left_path = tmp_path / "0.parquet"
    right_path = tmp_path / "1.parquet"
    left_df.write_parquet(left_path)
    right_df.write_parquet(right_path)

    left_lf = pl.scan_pyarrow_dataset(ds.dataset(left_path, format="parquet"))
    right_lf = pl.scan_pyarrow_dataset(ds.dataset(right_path, format="parquet"))

    joined = left_lf.join(right_lf, on="a", how="inner").collect()
    assert joined.to_dict(as_series=False) == {"a": [1, 2]}
230
231
232
def test_pyarrow_dataset_predicate_verbose_log(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """Verify the sensitive verbose log emitted when converting predicates
    for a pyarrow dataset scan, for both a convertible and an unconvertible
    predicate."""
    monkeypatch.setenv("POLARS_VERBOSE_SENSITIVE", "1")

    file_path_0 = tmp_path / "0"
    pl.DataFrame({"a": [1, 2, 3]}).write_parquet(file_path_0)
    dset = ds.dataset(file_path_0, format="parquet")

    def collect_and_capture(lf: pl.LazyFrame) -> str:
        # Drain any pending output first so only this query's log is captured.
        capfd.readouterr()
        assert_frame_equal(lf.collect(), pl.DataFrame({"a": [1, 2]}))
        return capfd.readouterr().err

    # Predicate that converts cleanly to a pyarrow compute expression.
    stderr_log = collect_and_capture(
        pl.scan_pyarrow_dataset(dset).filter(pl.col("a") < 3)
    )
    assert (
        "[SENSITIVE]: python_scan_predicate: "
        'predicate node: [(col("a")) < (3)], '
        "converted pyarrow predicate: (pa.compute.field('a') < 3)"
    ) in stderr_log

    # Predicate whose cast cannot be converted; the failure is logged.
    stderr_log = collect_and_capture(
        pl.scan_pyarrow_dataset(dset).filter(pl.col("a").cast(pl.String) < "3")
    )
    assert (
        "[SENSITIVE]: python_scan_predicate: "
        'predicate node: [(col("a").strict_cast(String)) < ("3")], '
        "converted pyarrow predicate: <conversion failed>\n"
    ) in stderr_log
268
269