GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/test_pyarrow_dataset.py
from __future__ import annotations

from datetime import date, datetime, time, timezone
from typing import TYPE_CHECKING

import pyarrow as pa
import pyarrow.dataset as ds
import pytest

import polars as pl
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
    from collections.abc import Callable
    from pathlib import Path

    from tests.conftest import PlMonkeyPatch

def helper_dataset_test(
    file_path: Path,
    query: Callable[[pl.LazyFrame], pl.LazyFrame],
    batch_size: int | None = None,
    n_expected: int | None = None,
    check_predicate_pushdown: bool = False,
) -> None:
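    """Run `query` through both a native IPC scan and a pyarrow-dataset scan.

    The two plans must collect to identical frames. `n_expected` optionally
    checks the row count, and `check_predicate_pushdown` asserts the optimized
    plan contains no `FILTER` node (i.e. the predicate reached the scan).
    Note: equality predicates on columns containing nulls can behave
    differently, as pyarrow handles Kleene logic differently than Polars.
    """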
    dset = ds.dataset(file_path, format="ipc")
    q = pl.scan_ipc(file_path).pipe(query)

    expected = q.collect()
    out = pl.scan_pyarrow_dataset(dset, batch_size=batch_size).pipe(query).collect()
    assert_frame_equal(out, expected)
    if n_expected is not None:
        assert len(out) == n_expected

    if check_predicate_pushdown:
        assert "FILTER" not in q.explain()


# @pytest.mark.write_disk()
def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None:
    file_path = tmp_path / "small.ipc"
    df.write_ipc(file_path)

    helper_dataset_test(
        file_path,
        lambda lf: lf.filter("bools").select("bools", "floats", "date"),
        n_expected=1,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(~pl.col("bools")).select("bools", "floats", "date"),
        n_expected=2,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("int_nulls").is_null()).select(
            "bools", "floats", "date"
        ),
        n_expected=1,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("int_nulls").is_not_null()).select(
            "bools", "floats", "date"
        ),
        n_expected=2,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(
            pl.col("int_nulls").is_not_null() == pl.col("bools")
        ).select("bools", "floats", "date"),
        n_expected=0,
        check_predicate_pushdown=True,
    )
    # This equality on a column with nulls fails because pyarrow handles
    # Kleene logic differently; we leave it for now and document it in the
    # helper function.
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("int") == 10).select(
            "bools", "floats", "int_nulls"
        ),
        n_expected=0,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("int") != 10).select(
            "bools", "floats", "int_nulls"
        ),
        n_expected=3,
        check_predicate_pushdown=True,
    )

    for closed, n_expected in zip(
        ["both", "left", "right", "none"], [3, 2, 2, 1], strict=True
    ):
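        # Bind the loop variable as a default argument (`closed=closed`) so
        # each lambda captures the current value, not the final loop value.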
        helper_dataset_test(
            file_path,
            lambda lf, closed=closed: lf.filter(  # type: ignore[misc]
                pl.col("int").is_between(1, 3, closed=closed)
            ).select("bools", "floats", "date"),
            n_expected=n_expected,
            check_predicate_pushdown=True,
        )
    # This predicate is not supported by pyarrow;
    # check that we still apply it on our side.
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("floats").sum().over("date") == 10).select(
            "bools", "floats", "date"
        ),
        n_expected=0,
    )
    # temporal types
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("date") < date(1972, 1, 1)).select(
            "bools", "floats", "date"
        ),
        n_expected=1,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(
            pl.col("datetime") > datetime(1970, 1, 1, second=13)
        ).select("bools", "floats", "date"),
        n_expected=1,
        check_predicate_pushdown=True,
    )
    # not yet supported in pyarrow
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("time") >= time(microsecond=100)).select(
            "bools", "time", "date"
        ),
        n_expected=3,
        check_predicate_pushdown=True,
    )
    # pushdown of is_in
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("int").is_in([1, 3, 20])).select(
            "bools", "floats", "date"
        ),
        n_expected=2,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(
            pl.col("date").is_in([date(1973, 8, 17), date(1973, 5, 19)])
        ).select("bools", "floats", "date"),
        n_expected=2,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(
            pl.col("datetime").is_in(
                [
                    datetime(1970, 1, 1, 0, 0, 12, 341234),
                    datetime(1970, 1, 1, 0, 0, 13, 241324),
                ]
            )
        ).select("bools", "floats", "date"),
        n_expected=2,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("int").is_in(list(range(120)))).select(
            "bools", "floats", "date"
        ),
        n_expected=3,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("cat").is_in([])).select("bools", "floats", "date"),
        n_expected=0,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.select(pl.exclude("enum")),
        batch_size=2,
        n_expected=3,
    )

    # direct filter
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.Series([True, False, True])).select(
            "bools", "floats", "date"
        ),
        n_expected=2,
    )

    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("bools") & pl.col("int").is_in([1, 2])).select(
            "bools", "floats"
        ),
        n_expected=1,
        check_predicate_pushdown=True,
    )


def test_pyarrow_dataset_partial_predicate_pushdown(
    tmp_path: Path,
    plmonkeypatch: PlMonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    plmonkeypatch.setenv("POLARS_VERBOSE_SENSITIVE", "1")

    df = pl.DataFrame({"a": [1, 2, 3], "b": [10.0, 20.0, 30.0]})
    file_path = tmp_path / "0"
    df.write_parquet(file_path)
    dset = ds.dataset(file_path, format="parquet")

    # col("a") > 1 is convertible; col("a") * col("b") > 25 is not (arithmetic
    # on two columns cannot be expressed as a pyarrow compute expression).
    # The optimizer pushes both terms into the scan's SELECTION, so our
    # MintermIter-based partial conversion should push the convertible part.
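    # Expected split, matching the log assertions below (illustrative):
    #   pushed to pyarrow: (pa.compute.field('a') > 1)
    #   residual (Polars): ([(col("a").cast(Float64)) * (col("b"))]) > (25.0)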
    q = pl.scan_pyarrow_dataset(dset).filter(
        (pl.col("a") > 1) & (pl.col("a") * pl.col("b") > 25)
    )

    capfd.readouterr()
    result = q.collect()
    capture = capfd.readouterr().err

    # Verify: partial predicate was pushed to pyarrow
    assert "(pa.compute.field('a') > 1)" in capture
    assert (
        'residual predicate: Some([([(col("a").cast(Float64)) * (col("b"))]) > (25.0)])'
        in capture
    )
    # Verify: correctness
    expected = (
        df.lazy().filter((pl.col("a") > 1) & (pl.col("a") * pl.col("b") > 25)).collect()
    )
    assert_frame_equal(result, expected)

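# Two scans of distinct datasets must remain distinct after common-subplan
# elimination; merging them would make the join below read one input twice.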
def test_pyarrow_dataset_comm_subplan_elim(tmp_path: Path) -> None:
    df0 = pl.DataFrame({"a": [1, 2, 3]})

    df1 = pl.DataFrame({"a": [1, 2]})

    file_path_0 = tmp_path / "0.parquet"
    file_path_1 = tmp_path / "1.parquet"

    df0.write_parquet(file_path_0)
    df1.write_parquet(file_path_1)

    ds0 = ds.dataset(file_path_0, format="parquet")
    ds1 = ds.dataset(file_path_1, format="parquet")

    lf0 = pl.scan_pyarrow_dataset(ds0)
    lf1 = pl.scan_pyarrow_dataset(ds1)

    assert_frame_equal(
        lf0.join(lf1, on="a", how="inner").collect(),
        pl.DataFrame({"a": [1, 2]}),
        check_row_order=False,
    )

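# With POLARS_VERBOSE_SENSITIVE set, the scan should log the predicate node,
# the converted pyarrow predicate (or "<conversion failed>"), and the residual
# predicate, as asserted below.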
def test_pyarrow_dataset_predicate_verbose_log(
    tmp_path: Path,
    plmonkeypatch: PlMonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    plmonkeypatch.setenv("POLARS_VERBOSE_SENSITIVE", "1")

    df = pl.DataFrame({"a": [1, 2, 3]})
    file_path_0 = tmp_path / "0"

    df.write_parquet(file_path_0)
    dset = ds.dataset(file_path_0, format="parquet")

    q = pl.scan_pyarrow_dataset(dset).filter(pl.col("a") < 3)

    capfd.readouterr()
    assert_frame_equal(q.collect(), pl.DataFrame({"a": [1, 2]}))
    capture = capfd.readouterr().err

    assert (
        "[SENSITIVE]: python_scan_predicate: "
        'predicate node: [(col("a")) < (3)], '
        "converted pyarrow predicate: (pa.compute.field('a') < 3), "
        "residual predicate: None"
    ) in capture

    q = pl.scan_pyarrow_dataset(dset).filter(pl.col("a").cast(pl.String) < "3")

    capfd.readouterr()
    assert_frame_equal(q.collect(), pl.DataFrame({"a": [1, 2]}))
    capture = capfd.readouterr().err

    assert (
        "[SENSITIVE]: python_scan_predicate: "
        'predicate node: [(col("a").strict_cast(String)) < ("3")], '
        "converted pyarrow predicate: <conversion failed>, "
        'residual predicate: Some([(col("a").strict_cast(String)) < ("3")])'
    ) in capture

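# The pyarrow-dataset python scan should also run under the streaming engine.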
@pytest.mark.write_disk
def test_pyarrow_dataset_python_scan(tmp_path: Path) -> None:
    df = pl.DataFrame({"x": [0, 1, 2, 3]})
    file_path = tmp_path / "0.parquet"
    df.write_parquet(file_path)

    dataset = ds.dataset(file_path)
    lf = pl.scan_pyarrow_dataset(dataset)
    out = lf.collect(engine="streaming")

    assert_frame_equal(df, out)

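# With allow_pyarrow_filter=False, no predicate is pushed into pyarrow;
# the filter must be applied by Polars itself.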
def test_pyarrow_dataset_allow_pyarrow_filter_false() -> None:
    df = pl.DataFrame({"item": ["foo", "bar", "baz"], "price": [10.0, 20.0, 30.0]})
    dataset = ds.dataset(df.to_arrow(compat_level=pl.CompatLevel.oldest()))

    # basic scan without filter
    result = pl.scan_pyarrow_dataset(dataset, allow_pyarrow_filter=False).collect()
    assert_frame_equal(result, df)

    # with filter (predicate should be applied by Polars, not PyArrow)
    result = (
        pl.scan_pyarrow_dataset(dataset, allow_pyarrow_filter=False)
        .filter(pl.col("price") > 15)
        .collect()
    )

    expected = pl.DataFrame({"item": ["bar", "baz"], "price": [20.0, 30.0]})
    assert_frame_equal(result, expected)

    # check user-specified `batch_size` doesn't error (ref: #25316)
    result = (
        pl.scan_pyarrow_dataset(dataset, allow_pyarrow_filter=False, batch_size=1000)
        .filter(pl.col("price") > 15)
        .collect()
    )
    assert_frame_equal(result, expected)

    # check `allow_pyarrow_filter=True` still works
    result = (
        pl.scan_pyarrow_dataset(dataset, allow_pyarrow_filter=True)
        .filter(pl.col("price") > 15)
        .collect()
    )
    assert_frame_equal(result, expected)

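# Regression test (issue 26029, per the test-name suffix): predicates on
# timezone-aware datetime columns must convert and filter correctly.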
def test_scan_pyarrow_dataset_filter_with_timezone_26029() -> None:
    table = pa.table(
        {
            "valid_from": [
                datetime(2025, 8, 26, 10, 0, 0, tzinfo=timezone.utc),
                datetime(2025, 8, 26, 11, 0, 0, tzinfo=timezone.utc),
            ],
            "valid_to": [
                datetime(2025, 8, 26, 12, 0, 0, tzinfo=timezone.utc),
                datetime(2025, 8, 26, 13, 0, 0, tzinfo=timezone.utc),
            ],
            "value": [1, 2],
        }
    )
    dataset = ds.dataset(table)

    lower_bound_time = datetime(2025, 8, 26, 11, 30, 0, tzinfo=timezone.utc)
    lf = pl.scan_pyarrow_dataset(dataset).filter(
        (pl.col("valid_from") <= lower_bound_time)
        & (pl.col("valid_to") > lower_bound_time)
    )

    assert_frame_equal(lf.collect(), pl.DataFrame(table))

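# `head(2)` precedes the filter, so the slice must apply before the predicate.
# The same contract is then exercised directly on the low-level
# `_scan_pyarrow_dataset_impl` (predicate plus n_rows, n_rows=0, and
# allow_pyarrow_filter=False paths).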
def test_scan_pyarrow_dataset_filter_slice_order() -> None:
    table = pa.table(
        {
            "index": [0, 1, 2],
            "year": [2025, 2026, 2026],
            "month": [0, 0, 0],
        }
    )
    dataset = ds.dataset(table)

    q = pl.scan_pyarrow_dataset(dataset).head(2).filter(pl.col("year") == 2026)

    assert_frame_equal(
        q.collect(),
        pl.DataFrame({"index": 1, "year": 2026, "month": 0}),
    )

    import polars.io.pyarrow_dataset.anonymous_scan

    assert_frame_equal(
        polars.io.pyarrow_dataset.anonymous_scan._scan_pyarrow_dataset_impl(
            dataset,
            n_rows=2,
            predicate="pa.compute.field('year') == 2026",
            with_columns=None,
        ),
        pl.DataFrame({"index": 1, "year": 2026, "month": 0}),
    )

    assert_frame_equal(
        polars.io.pyarrow_dataset.anonymous_scan._scan_pyarrow_dataset_impl(
            dataset,
            n_rows=0,
            predicate="pa.compute.field('year') == 2026",
            with_columns=None,
        ),
        pl.DataFrame(schema={"index": pl.Int64, "year": pl.Int64, "month": pl.Int64}),
    )

    assert_frame_equal(
        pl.concat(
            polars.io.pyarrow_dataset.anonymous_scan._scan_pyarrow_dataset_impl(
                dataset,
                n_rows=1,
                predicate=None,
                with_columns=None,
                allow_pyarrow_filter=False,
            )[0]
        ),
        pl.DataFrame({"index": 0, "year": 2025, "month": 0}),
    )

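    # With allow_pyarrow_filter=False the impl appears to return a tuple whose
    # second element signals whether pyarrow applied the predicate; it must be
    # falsy here because pushdown is disabled.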
    assert not polars.io.pyarrow_dataset.anonymous_scan._scan_pyarrow_dataset_impl(
        dataset,
        n_rows=0,
        predicate="pa.compute.field('year') == 2026",
        with_columns=None,
        allow_pyarrow_filter=False,
    )[1]