GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/test_hive.py
from __future__ import annotations

import sys
import urllib.parse
import warnings
from collections import OrderedDict
from datetime import date, datetime
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any

import pyarrow.parquet as pq
import pytest

import polars as pl
from polars.exceptions import ComputeError, SchemaFieldNotFoundError
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.io.conftest import format_file_uri

if TYPE_CHECKING:
    from collections.abc import Callable

    from tests.conftest import PlMonkeyPatch


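# These tests exercise Hive-style partitioning, where partition column values are
# encoded in `key=value` directory names (e.g. `category=vegetables/fats_g=0.5/`)
# and reconstructed as columns when scanning.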
def impl_test_hive_partitioned_predicate_pushdown(
    io_files_path: Path,
    tmp_path: Path,
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    plmonkeypatch.setenv("POLARS_VERBOSE", "1")
    df = pl.read_ipc(io_files_path / "*.ipc")

    root = tmp_path / "partitioned_data"

    pq.write_to_dataset(
        df.to_arrow(),
        root_path=root,
        partition_cols=["category", "fats_g"],
    )
    q = pl.scan_parquet(root / "**/*.parquet", hive_partitioning=False)
    # checks schema
    assert q.collect_schema().names() == ["calories", "sugars_g"]
    # checks materialization
    assert q.collect().columns == ["calories", "sugars_g"]

    q = pl.scan_parquet(root / "**/*.parquet", hive_partitioning=True)
    assert q.collect_schema().names() == ["calories", "sugars_g", "category", "fats_g"]

    # Partitioning changes the order
    sort_by = ["fats_g", "category", "calories", "sugars_g"]

    # The hive partitioned columns are appended,
    # so we must ensure we assert in the proper order.
    df = df.select(["calories", "sugars_g", "category", "fats_g"])
    for streaming in [True, False]:
        for pred in [
            pl.col("category") == "vegetables",
            pl.col("category") != "vegetables",
            pl.col("fats_g") > 0.5,
            (pl.col("fats_g") == 0.5) & (pl.col("category") == "vegetables"),
        ]:
            assert_frame_equal(
                q.filter(pred)
                .sort(sort_by)
                .collect(engine="streaming" if streaming else "in-memory"),
                df.filter(pred).sort(sort_by),
            )

    # tests: 11536
    assert q.filter(pl.col("sugars_g") == 25).collect().shape == (1, 4)

    # tests: 12570
    assert q.filter(pl.col("fats_g") == 1225.0).select("category").collect().shape == (
        0,
        1,
    )


@pytest.mark.xdist_group("streaming")
@pytest.mark.write_disk
def test_hive_partitioned_predicate_pushdown(
    io_files_path: Path,
    tmp_path: Path,
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    impl_test_hive_partitioned_predicate_pushdown(
        io_files_path,
        tmp_path,
        plmonkeypatch,
    )


@pytest.mark.xdist_group("streaming")
@pytest.mark.write_disk
def test_hive_partitioned_predicate_pushdown_single_threaded_async_17155(
    io_files_path: Path,
    tmp_path: Path,
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    plmonkeypatch.setenv("POLARS_FORCE_ASYNC", "1")
    plmonkeypatch.setenv("POLARS_PREFETCH_SIZE", "1")

    impl_test_hive_partitioned_predicate_pushdown(
        io_files_path,
        tmp_path,
        plmonkeypatch,
    )


@pytest.mark.write_disk
@pytest.mark.may_fail_auto_streaming
@pytest.mark.may_fail_cloud  # reason: inspects logs
def test_hive_partitioned_predicate_pushdown_skips_correct_number_of_files(
    tmp_path: Path, plmonkeypatch: PlMonkeyPatch, capfd: Any
) -> None:
    plmonkeypatch.setenv("POLARS_VERBOSE", "1")
    df = pl.DataFrame({"d": pl.arange(0, 5, eager=True)}).with_columns(
        a=pl.col("d") % 5
    )
    root = tmp_path / "test_int_partitions"
    df.write_parquet(
        root,
        use_pyarrow=True,
        pyarrow_options={"partition_cols": ["a"]},
    )

    q = pl.scan_parquet(root / "**/*.parquet", hive_partitioning=True)
    assert q.filter(pl.col("a").is_in([1, 4])).collect().shape == (2, 2)
    assert "allows skipping 3 / 5" in capfd.readouterr().err

    # Ensure the CSE can work with hive partitions.
    q = q.filter(pl.col("a").gt(2))
    result = q.join(q, on="a", how="left").collect(
        optimizations=pl.QueryOptFlags(comm_subplan_elim=True)
    )
    expected = {
        "a": [3, 4],
        "d": [3, 4],
        "d_right": [3, 4],
    }
    assert result.to_dict(as_series=False) == expected


@pytest.mark.write_disk
def test_hive_streaming_pushdown_is_in_22212(tmp_path: Path) -> None:
    (
        pl.DataFrame({"x": range(5)}).write_parquet(
            tmp_path,
            partition_by="x",
        )
    )

    lf = pl.scan_parquet(tmp_path, hive_partitioning=True).filter(
        pl.col("x").is_in([1, 4])
    )

    assert_frame_equal(
        lf.collect(
            engine="streaming", optimizations=pl.QueryOptFlags(predicate_pushdown=False)
        ),
        lf.collect(
            engine="streaming", optimizations=pl.QueryOptFlags(predicate_pushdown=True)
        ),
    )


@pytest.mark.xdist_group("streaming")
@pytest.mark.write_disk
@pytest.mark.parametrize("streaming", [True, False])
def test_hive_partitioned_slice_pushdown(
    io_files_path: Path, tmp_path: Path, streaming: bool
) -> None:
    df = pl.read_ipc(io_files_path / "*.ipc")

    root = tmp_path / "partitioned_data"

    # Ignore the pyarrow legacy warning until we can write properly with new settings.
    warnings.filterwarnings("ignore")
    pq.write_to_dataset(
        df.to_arrow(),
        root_path=root,
        partition_cols=["category", "fats_g"],
    )

    q = pl.scan_parquet(root / "**/*.parquet", hive_partitioning=True)
    schema = q.collect_schema()
    expect_count = pl.select(pl.lit(1, dtype=pl.UInt32).alias(x) for x in schema)

    assert_frame_equal(
        q.head(1)
        .collect(engine="streaming" if streaming else "in-memory")
        .select(pl.all().len()),
        expect_count,
    )
    assert q.head(0).collect(
        engine="streaming" if streaming else "in-memory"
    ).columns == [
        "calories",
        "sugars_g",
        "category",
        "fats_g",
    ]


@pytest.mark.xdist_group("streaming")
@pytest.mark.write_disk
def test_hive_partitioned_projection_pushdown(
    io_files_path: Path, tmp_path: Path
) -> None:
    df = pl.read_ipc(io_files_path / "*.ipc")

    root = tmp_path / "partitioned_data"

    # Ignore the pyarrow legacy warning until we can write properly with new settings.
    warnings.filterwarnings("ignore")
    pq.write_to_dataset(
        df.to_arrow(),
        root_path=root,
        partition_cols=["category", "fats_g"],
    )

    q = pl.scan_parquet(root / "**/*.parquet", hive_partitioning=True)
    columns = ["sugars_g", "category"]
    for streaming in [True, False]:
        assert (
            q.select(columns)
            .collect(engine="streaming" if streaming else "in-memory")
            .columns
            == columns
        )

    # test that hive partition columns are projected with the correct height when
    # the projection contains only hive partition columns (11796)
    for parallel in ("row_groups", "columns"):
        q = pl.scan_parquet(
            root / "**/*.parquet",
            hive_partitioning=True,
            parallel=parallel,
        )

        expected = q.collect().select("category")
        result = q.select("category").collect()

        assert_frame_equal(result, expected)


@pytest.mark.write_disk
def test_hive_partitioned_projection_skips_files(tmp_path: Path) -> None:
    # Ensure that hive columns are parsed even when the directory value contains a "."
    # and that no hive column is derived from a filename containing "=".
    df = pl.DataFrame(
        {"sqlver": [10012.0, 10013.0], "namespace": ["eos", "fda"], "a": [1, 2]}
    )
    root = tmp_path / "partitioned_data"
    for dir_tuple, sub_df in df.partition_by(
        ["sqlver", "namespace"], include_key=False, as_dict=True
    ).items():
        new_path = root / f"sqlver={dir_tuple[0]}" / f"namespace={dir_tuple[1]}"
        new_path.mkdir(parents=True, exist_ok=True)
        sub_df.write_parquet(new_path / "file=8484.parquet")
    test_df = (
        pl.scan_parquet(str(root) + "/**/**/*.parquet", hive_partitioning=True)
        # don't care about column order
        .select("sqlver", "namespace", "a", pl.exclude("sqlver", "namespace", "a"))
        .collect()
    )
    assert_frame_equal(df, test_df)


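# The fixture below writes a small two-partition dataset:
#   dataset/c=1/one.parquet, dataset/c=1/two.parquet, dataset/c=2/three.parquet
# The `c` column exists only in the directory names, so it must be derived from
# the paths when scanning.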
@pytest.fixture
def dataset_path(tmp_path: Path) -> Path:
    tmp_path.mkdir(exist_ok=True)

    # Set up Hive partitioned Parquet file
    root = tmp_path / "dataset"
    part1 = root / "c=1"
    part2 = root / "c=2"
    root.mkdir()
    part1.mkdir()
    part2.mkdir()
    df1 = pl.DataFrame({"a": [1, 2], "b": [11.0, 12.0]})
    df2 = pl.DataFrame({"a": [3, 4], "b": [13.0, 14.0]})
    df3 = pl.DataFrame({"a": [5, 6], "b": [15.0, 16.0]})
    df1.write_parquet(part1 / "one.parquet")
    df2.write_parquet(part1 / "two.parquet")
    df3.write_parquet(part2 / "three.parquet")

    return root


@pytest.mark.write_disk
def test_scan_parquet_hive_schema(dataset_path: Path) -> None:
    result = pl.scan_parquet(dataset_path / "**/*.parquet", hive_partitioning=True)
    assert result.collect_schema() == OrderedDict(
        {"a": pl.Int64, "b": pl.Float64, "c": pl.Int64}
    )

    result = pl.scan_parquet(
        dataset_path / "**/*.parquet",
        hive_partitioning=True,
        hive_schema={"c": pl.Int32},
    )

    expected_schema = OrderedDict({"a": pl.Int64, "b": pl.Float64, "c": pl.Int32})
    assert result.collect_schema() == expected_schema
    assert result.collect().schema == expected_schema


@pytest.mark.write_disk
def test_read_parquet_invalid_hive_schema(dataset_path: Path) -> None:
    with pytest.raises(
        SchemaFieldNotFoundError,
        match='path contains column not present in the given Hive schema: "c"',
    ):
        pl.read_parquet(
            dataset_path / "**/*.parquet",
            hive_partitioning=True,
            hive_schema={"nonexistent": pl.Int32},
        )


def test_read_parquet_hive_schema_with_pyarrow(tmp_path: Path) -> None:
    with pytest.raises(
        TypeError,
        match="cannot use `hive_partitions` with `use_pyarrow=True`",
    ):
        pl.read_parquet(
            tmp_path / "test.parquet", hive_schema={"c": pl.Int32}, use_pyarrow=True
        )


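# The test below covers how `hive_partitioning=None` (the default), `True` and
# `False` interact with directory inputs, single files, lists of paths and glob
# patterns, for both the parquet and IPC scan functions.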
@pytest.mark.parametrize(
    ("scan_func", "write_func"),
    [
        (pl.scan_parquet, pl.DataFrame.write_parquet),
        (pl.scan_ipc, pl.DataFrame.write_ipc),
    ],
)
@pytest.mark.parametrize(
    "glob",
    [True, False],
)
def test_hive_partition_directory_scan(
    tmp_path: Path,
    scan_func: Callable[..., pl.LazyFrame],
    write_func: Callable[[pl.DataFrame, Path], None],
    glob: bool,
) -> None:
    tmp_path.mkdir(exist_ok=True)

    dfs = [
        pl.DataFrame({'x': 5 * [1], 'a': 1, 'b': 1}),
        pl.DataFrame({'x': 5 * [2], 'a': 1, 'b': 2}),
        pl.DataFrame({'x': 5 * [3], 'a': 22, 'b': 1}),
        pl.DataFrame({'x': 5 * [4], 'a': 22, 'b': 2}),
    ]  # fmt: skip

    for df in dfs:
        a = df.item(0, "a")
        b = df.item(0, "b")
        path = tmp_path / f"a={a}/b={b}/data.bin"
        path.parent.mkdir(exist_ok=True, parents=True)
        write_func(df.drop("a", "b"), path)

    df = pl.concat(dfs)
    hive_schema = df.lazy().select("a", "b").collect_schema()

    scan = scan_func

    if scan_func is pl.scan_parquet:
        scan = partial(scan, glob=glob)

    scan_with_hive_schema = partial(scan_func, hive_schema=hive_schema)

    out = scan_with_hive_schema(
        tmp_path,
        hive_partitioning=True,
    ).collect()
    assert_frame_equal(out, df)

    out = scan(tmp_path, hive_partitioning=False).collect()
    assert_frame_equal(out, df.drop("a", "b"))

    out = scan_with_hive_schema(
        tmp_path / "a=1",
        hive_partitioning=True,
    ).collect()
    assert_frame_equal(out, df.filter(a=1).drop("a"))

    out = scan(tmp_path / "a=1", hive_partitioning=False).collect()
    assert_frame_equal(out, df.filter(a=1).drop("a", "b"))

    path = tmp_path / "a=1/b=1/data.bin"

    out = scan_with_hive_schema(path, hive_partitioning=True).collect()
    assert_frame_equal(out, dfs[0])

    out = scan(path, hive_partitioning=False).collect()
    assert_frame_equal(out, dfs[0].drop("a", "b"))

    # Test default behavior with `hive_partitioning=None`, which should only
    # enable hive partitioning when a single directory is passed:
    out = scan_with_hive_schema(tmp_path).collect()
    assert_frame_equal(out, df)

    # Otherwise, hive partitioning is not enabled automatically:
    out = scan(tmp_path / "a=1/b=1/data.bin").collect()
    assert out.columns == ["x"]

    out = scan([tmp_path / "a=1/", tmp_path / "a=22/"]).collect()
    assert out.columns == ["x"]

    out = scan([tmp_path / "a=1/", tmp_path / "a=22/b=1/data.bin"]).collect()
    assert out.columns == ["x"]

    if glob:
        out = scan(tmp_path / "a=1/**/*.bin").collect()
        assert out.columns == ["x"]

    # Test `hive_partitioning=True`
    out = scan_with_hive_schema(tmp_path, hive_partitioning=True).collect()
    assert_frame_equal(out, df)

    # Accept multiple directories from the same level
    out = scan_with_hive_schema(
        [tmp_path / "a=1", tmp_path / "a=22"], hive_partitioning=True
    ).collect()
    assert_frame_equal(out, df.drop("a"))

    with pytest.raises(
        pl.exceptions.InvalidOperationError,
        match="attempted to read from different directory levels with hive partitioning enabled:",
    ):
        scan_with_hive_schema(
            [tmp_path / "a=1", tmp_path / "a=22/b=1"], hive_partitioning=True
        ).collect()

    if glob:
        out = scan_with_hive_schema(
            tmp_path / "**/*.bin", hive_partitioning=True
        ).collect()
        assert_frame_equal(out, df)

        # Parse hive from full path for glob patterns
        out = scan_with_hive_schema(
            [tmp_path / "a=1/**/*.bin", tmp_path / "a=22/**/*.bin"],
            hive_partitioning=True,
        ).collect()
        assert_frame_equal(out, df)

    # Parse hive from full path for files
    out = scan_with_hive_schema(
        tmp_path / "a=1/b=1/data.bin", hive_partitioning=True
    ).collect()
    assert_frame_equal(out, df.filter(a=1, b=1))

    out = scan_with_hive_schema(
        [tmp_path / "a=1/b=1/data.bin", tmp_path / "a=22/b=1/data.bin"],
        hive_partitioning=True,
    ).collect()
    assert_frame_equal(
        out,
        df.filter(
            ((pl.col("a") == 1) & (pl.col("b") == 1))
            | ((pl.col("a") == 22) & (pl.col("b") == 1))
        ),
    )

    # Test `hive_partitioning=False`
    out = scan(tmp_path, hive_partitioning=False).collect()
    assert_frame_equal(out, df.drop("a", "b"))

    if glob:
        out = scan(tmp_path / "**/*.bin", hive_partitioning=False).collect()
        assert_frame_equal(out, df.drop("a", "b"))

    out = scan(tmp_path / "a=1/b=1/data.bin", hive_partitioning=False).collect()
    assert_frame_equal(out, df.filter(a=1, b=1).drop("a", "b"))


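# Hive partition dtypes are inferred from the path values: all-integer values
# give Int64, mixed int/float values give Float64, and anything else falls back
# to String.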
def test_hive_partition_schema_inference(tmp_path: Path) -> None:
    tmp_path.mkdir(exist_ok=True)

    dfs = [
        pl.DataFrame({"x": 1}),
        pl.DataFrame({"x": 2}),
        pl.DataFrame({"x": 3}),
    ]

    paths = [
        tmp_path / "a=1/data.bin",
        tmp_path / "a=1.5/data.bin",
        tmp_path / "a=polars/data.bin",
    ]

    expected = [
        pl.Series("a", [1], dtype=pl.Int64),
        pl.Series("a", [1.0, 1.5], dtype=pl.Float64),
        pl.Series("a", ["1", "1.5", "polars"], dtype=pl.String),
    ]

    for i in range(3):
        paths[i].parent.mkdir(exist_ok=True, parents=True)
        dfs[i].write_parquet(paths[i])
        out = pl.scan_parquet(tmp_path).sort("x").collect()

        assert_series_equal(out["a"], expected[i])


@pytest.mark.write_disk
def test_hive_partition_force_async_17155(
    tmp_path: Path, plmonkeypatch: PlMonkeyPatch
) -> None:
    plmonkeypatch.setenv("POLARS_FORCE_ASYNC", "1")
    plmonkeypatch.setenv("POLARS_PREFETCH_SIZE", "1")

    dfs = [
        pl.DataFrame({"x": 1}),
        pl.DataFrame({"x": 2}),
        pl.DataFrame({"x": 3}),
    ]

    paths = [
        tmp_path / "a=1/b=1/data.bin",
        tmp_path / "a=2/b=2/data.bin",
        tmp_path / "a=3/b=3/data.bin",
    ]

    for i in range(3):
        paths[i].parent.mkdir(exist_ok=True, parents=True)
        dfs[i].write_parquet(paths[i])

    lf = pl.scan_parquet(tmp_path)

    assert_frame_equal(
        lf.collect(), pl.DataFrame({k: [1, 2, 3] for k in ["x", "a", "b"]})
    )


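# The partition columns below are also stored inside the data files themselves;
# these cases check that the values are reconciled with the path values, that an
# explicit `hive_schema` decides the output dtype, and that this holds across the
# different `parallel` strategies.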
@pytest.mark.parametrize(
    ("scan_func", "write_func"),
    [
        (partial(pl.scan_parquet, parallel="row_groups"), pl.DataFrame.write_parquet),
        (partial(pl.scan_parquet, parallel="columns"), pl.DataFrame.write_parquet),
        (partial(pl.scan_parquet, parallel="prefiltered"), pl.DataFrame.write_parquet),
        (
            lambda *a, **kw: pl.scan_parquet(*a, parallel="prefiltered", **kw).filter(
                pl.col("b") == pl.col("b")
            ),
            pl.DataFrame.write_parquet,
        ),
        (pl.scan_ipc, pl.DataFrame.write_ipc),
    ],
)
@pytest.mark.write_disk
@pytest.mark.slow
@pytest.mark.parametrize("projection_pushdown", [True, False])
def test_hive_partition_columns_contained_in_file(
    tmp_path: Path,
    scan_func: Callable[[Any], pl.LazyFrame],
    write_func: Callable[[pl.DataFrame, Path], None],
    projection_pushdown: bool,
) -> None:
    path = tmp_path / "a=1/b=2/data.bin"
    path.parent.mkdir(exist_ok=True, parents=True)
    df = pl.DataFrame(
        {"x": 1, "a": 1, "b": 2, "y": 1},
        schema={"x": pl.Int32, "a": pl.Int8, "b": pl.Int16, "y": pl.Int32},
    )
    write_func(df, path)

    def assert_with_projections(
        lf: pl.LazyFrame, df: pl.DataFrame, *, row_index: str | None = None
    ) -> None:
        row_index: list[str] = [row_index] if row_index is not None else []  # type: ignore[no-redef]

        from itertools import permutations

        cols = ["a", "b", "x", "y", *row_index]  # type: ignore[misc]

        for projection in (
            x for i in range(len(cols)) for x in permutations(cols[: 1 + i])
        ):
            assert_frame_equal(
                lf.select(projection).collect(
                    optimizations=pl.QueryOptFlags(
                        projection_pushdown=projection_pushdown
                    )
                ),
                df.select(projection),
            )

    lf = scan_func(path, hive_partitioning=True)  # type: ignore[call-arg]
    rhs = df
    assert_frame_equal(
        lf.collect(
            optimizations=pl.QueryOptFlags(projection_pushdown=projection_pushdown)
        ),
        rhs,
    )
    assert_with_projections(lf, rhs)

    lf = scan_func(  # type: ignore[call-arg]
        path,
        hive_schema={"a": pl.String, "b": pl.String},
        hive_partitioning=True,
    )
    rhs = df.with_columns(pl.col("a", "b").cast(pl.String))
    assert_frame_equal(
        lf.collect(
            optimizations=pl.QueryOptFlags(projection_pushdown=projection_pushdown)
        ),
        rhs,
    )
    assert_with_projections(lf, rhs)

    # partial cols in file
    partial_path = tmp_path / "a=1/b=2/partial_data.bin"
    df = pl.DataFrame(
        {"x": 1, "b": 2, "y": 1},
        schema={"x": pl.Int32, "b": pl.Int16, "y": pl.Int32},
    )
    write_func(df, partial_path)

    rhs = rhs.select(
        pl.col("x").cast(pl.Int32),
        pl.col("b").cast(pl.Int16),
        pl.col("y").cast(pl.Int32),
        pl.col("a").cast(pl.Int64),
    )

    lf = scan_func(partial_path, hive_partitioning=True)  # type: ignore[call-arg]
    assert_frame_equal(
        lf.collect(
            optimizations=pl.QueryOptFlags(projection_pushdown=projection_pushdown)
        ),
        rhs,
    )
    assert_with_projections(lf, rhs)

    assert_frame_equal(
        lf.with_row_index().collect(
            optimizations=pl.QueryOptFlags(projection_pushdown=projection_pushdown)
        ),
        rhs.with_row_index(),
    )
    assert_with_projections(
        lf.with_row_index(), rhs.with_row_index(), row_index="index"
    )

    assert_frame_equal(
        lf.with_row_index()
        .select(pl.exclude("index"), "index")
        .collect(
            optimizations=pl.QueryOptFlags(projection_pushdown=projection_pushdown)
        ),
        rhs.with_row_index().select(pl.exclude("index"), "index"),
    )
    assert_with_projections(
        lf.with_row_index().select(pl.exclude("index"), "index"),
        rhs.with_row_index().select(pl.exclude("index"), "index"),
        row_index="index",
    )

    lf = scan_func(  # type: ignore[call-arg]
        partial_path,
        hive_schema={"a": pl.String, "b": pl.String},
        hive_partitioning=True,
    )
    rhs = rhs.select(
        pl.col("x").cast(pl.Int32),
        pl.col("b").cast(pl.String),
        pl.col("y").cast(pl.Int32),
        pl.col("a").cast(pl.String),
    )
    assert_frame_equal(
        lf.collect(
            optimizations=pl.QueryOptFlags(projection_pushdown=projection_pushdown)
        ),
        rhs,
    )
    assert_with_projections(lf, rhs)


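# Null partition values are written to a `__HIVE_DEFAULT_PARTITION__` directory,
# and date-like strings in paths are parsed back into temporal dtypes unless
# `try_parse_hive_dates=False`.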
@pytest.mark.write_disk
def test_hive_partition_dates(tmp_path: Path) -> None:
    df = pl.DataFrame(
        {
            "date1": [
                datetime(2024, 1, 1),
                datetime(2024, 2, 1),
                datetime(2024, 3, 1),
                None,
            ],
            "date2": [
                datetime(2023, 1, 1),
                datetime(2023, 2, 1),
                None,
                datetime(2023, 3, 1),
            ],
            "x": [1, 2, 3, 4],
        },
        schema={"date1": pl.Date, "date2": pl.Datetime, "x": pl.Int32},
    )

    root = tmp_path / "pyarrow"
    pq.write_to_dataset(
        df.to_arrow(),
        root_path=root,
        partition_cols=["date1", "date2"],
    )

    lf = pl.scan_parquet(
        root, hive_schema=df.clear().select("date1", "date2").collect_schema()
    )
    assert_frame_equal(lf.collect(), df.select("x", "date1", "date2"))

    lf = pl.scan_parquet(root)
    assert_frame_equal(lf.collect(), df.select("x", "date1", "date2"))

    lf = pl.scan_parquet(root, try_parse_hive_dates=False)
    assert_frame_equal(
        lf.collect(),
        df.select("x", "date1", "date2").with_columns(
            pl.col("date1", "date2").cast(pl.String)
        ),
    )

    for perc_escape in [True, False] if sys.platform != "win32" else [True]:
        root = tmp_path / f"includes_hive_cols_in_file_{perc_escape}"
        for (date1, date2), part_df in df.group_by(
            pl.col("date1").cast(pl.String).fill_null("__HIVE_DEFAULT_PARTITION__"),
            pl.col("date2").cast(pl.String).fill_null("__HIVE_DEFAULT_PARTITION__"),
        ):
            if perc_escape:
                date2 = urllib.parse.quote(date2)

            path = root / f"date1={date1}/date2={date2}/data.bin"
            path.parent.mkdir(exist_ok=True, parents=True)
            part_df.write_parquet(path)

        # The schema for the hive columns is included in the file, so it should
        # just work
        lf = pl.scan_parquet(root)
        assert_frame_equal(lf.collect(), df)

        lf = pl.scan_parquet(root, try_parse_hive_dates=False)
        assert_frame_equal(
            lf.collect(),
            df.with_columns(pl.col("date1", "date2").cast(pl.String)),
        )


@pytest.mark.write_disk
def test_hive_partition_filter_null_23005(tmp_path: Path) -> None:
    root = tmp_path

    df = pl.DataFrame(
        {
            "date1": [
                datetime(2024, 1, 1),
                datetime(2024, 2, 1),
                datetime(2024, 3, 1),
                None,
            ],
            "date2": [
                datetime(2023, 1, 1),
                datetime(2023, 2, 1),
                None,
                datetime(2023, 3, 1),
            ],
            "x": [1, 2, 3, 4],
        },
        schema={"date1": pl.Date, "date2": pl.Datetime, "x": pl.Int32},
    )

    df.write_parquet(root, partition_by=["date1", "date2"])

    q = pl.scan_parquet(root, include_file_paths="path")

    full = q.collect()

    assert (
        full.select(
            (
                pl.any_horizontal(pl.col("date1", "date2").is_null())
                & pl.col("path").str.contains("__HIVE_DEFAULT_PARTITION__")
            ).sum()
        ).item()
        == 2
    )

    lf = pl.scan_parquet(root).filter(pl.col("date1") == datetime(2024, 1, 1))
    assert_frame_equal(lf.collect(), df.head(1))


@pytest.mark.parametrize(
    ("scan_func", "write_func"),
    [
        (pl.scan_parquet, pl.DataFrame.write_parquet),
        (pl.scan_ipc, pl.DataFrame.write_ipc),
    ],
)
@pytest.mark.write_disk
def test_projection_only_hive_parts_gives_correct_number_of_rows(
    tmp_path: Path,
    scan_func: Callable[[Any], pl.LazyFrame],
    write_func: Callable[[pl.DataFrame, Path], None],
) -> None:
    # Check the number of rows projected when projecting only hive parts, which
    # should be the same as the number of rows in the file.
    path = tmp_path / "a=3/data.bin"
    path.parent.mkdir(exist_ok=True, parents=True)

    write_func(pl.DataFrame({"x": [1, 1, 1]}), path)

    assert_frame_equal(
        scan_func(path, hive_partitioning=True).select("a").collect(),  # type: ignore[call-arg]
        pl.DataFrame({"a": [3, 3, 3]}),
    )


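# Round-trip tests for partitioned writes: the second parametrization uses float
# keys and strings such as " / c = : " that have to be escaped when they appear
# in `key=value` directory names.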
@pytest.mark.parametrize(
    "df",
    [
        pl.select(
            pl.Series("a", [1, 2, 3, 4], dtype=pl.Int8),
            pl.Series("b", [1, 2, 3, 4], dtype=pl.Int8),
            pl.Series("x", [1, 2, 3, 4]),
        ),
        pl.select(
            pl.Series(
                "a",
                [1.2981275, 2.385974035, 3.1231892749185718397510, 4.129387128949156],
                dtype=pl.Float64,
            ),
            pl.Series("b", ["a", "b", " / c = : ", "d"]),
            pl.Series("x", [1, 2, 3, 4]),
        ),
    ],
)
@pytest.mark.write_disk
def test_hive_write(tmp_path: Path, df: pl.DataFrame) -> None:
    root = tmp_path
    df.write_parquet(root, partition_by=["a", "b"])

    lf = pl.scan_parquet(root)
    assert_frame_equal(lf.collect(), df)

    lf = pl.scan_parquet(root, hive_schema={"a": pl.String, "b": pl.String})
    assert_frame_equal(lf.collect(), df.with_columns(pl.col("a", "b").cast(pl.String)))


@pytest.mark.slow
@pytest.mark.write_disk
def test_hive_write_multiple_files(tmp_path: Path) -> None:
    chunk_size = 1
    n_rows = 500_000
    df = pl.select(a=pl.repeat(0, n_rows), b=pl.int_range(0, n_rows))

    root = tmp_path
    df.write_parquet(root, partition_by="a", partition_chunk_size_bytes=chunk_size)

    n_out = sum(1 for _ in (root / "a=0").iterdir())
    assert n_out == 5

    assert_frame_equal(pl.scan_parquet(root).collect(), df)


@pytest.mark.write_disk
def test_hive_write_dates(tmp_path: Path) -> None:
    df = pl.DataFrame(
        {
            "date1": [
                datetime(2024, 1, 1),
                datetime(2024, 2, 1),
                datetime(2024, 3, 1),
                None,
            ],
            "date2": [
                datetime(2023, 1, 1),
                datetime(2023, 2, 1),
                None,
                datetime(2023, 3, 1, 1, 1, 1, 1),
            ],
            "x": [1, 2, 3, 4],
        },
        schema={"date1": pl.Date, "date2": pl.Datetime, "x": pl.Int32},
    )

    root = tmp_path
    df.write_parquet(root, partition_by=["date1", "date2"])

    lf = pl.scan_parquet(root)
    assert_frame_equal(lf.collect(), df)

    lf = pl.scan_parquet(root, try_parse_hive_dates=False)
    assert_frame_equal(
        lf.collect(),
        df.with_columns(pl.col("date1", "date2").cast(pl.String)),
    )


@pytest.mark.write_disk
@pytest.mark.may_fail_auto_streaming
@pytest.mark.may_fail_cloud  # reason: inspects logs
def test_hive_predicate_dates_14712(
    tmp_path: Path, plmonkeypatch: PlMonkeyPatch, capfd: Any
) -> None:
    plmonkeypatch.setenv("POLARS_VERBOSE", "1")
    pl.DataFrame({"a": [datetime(2024, 1, 1)]}).write_parquet(
        tmp_path, partition_by="a"
    )
    pl.scan_parquet(tmp_path).filter(pl.col("a") != datetime(2024, 1, 1)).collect()
    assert "allows skipping 1 / 1" in capfd.readouterr().err


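# On Windows, hive path parsing has to treat `/` and `\` interchangeably,
# including in mixed-separator paths and `file:` URIs.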
@pytest.mark.skipif(sys.platform != "win32", reason="Test is only for Windows paths")
@pytest.mark.write_disk
@pytest.mark.parametrize("prefix", ["", "file:/", "file:///"])
def test_hive_windows_splits_on_forward_slashes(tmp_path: Path, prefix: str) -> None:
    # Note: This needs to be an absolute path.
    tmp_path = tmp_path.resolve()

    d = str(tmp_path)[:2]

    assert d[0].isalpha()
    assert d[1] == ":"

    path = f"{tmp_path}/a=1/b=1/c=1/d=1/e=1"
    Path(path).mkdir(exist_ok=True, parents=True)

    df = pl.DataFrame({"x": "x"})
    df.write_parquet(f"{path}/data.parquet")

    expect = pl.DataFrame(
        [
            s.new_from_index(0, 5)
            for s in pl.DataFrame(
                {
                    "x": "x",
                    "a": 1,
                    "b": 1,
                    "c": 1,
                    "d": 1,
                    "e": 1,
                }
            )
        ]
    )

    assert_frame_equal(
        pl.scan_parquet(
            [
                f"{prefix}{tmp_path}/a=1/b=1/c=1/d=1/e=1/data.parquet",
                f"{prefix}{tmp_path}\\a=1\\b=1\\c=1\\d=1\\e=1\\data.parquet",
                f"{prefix}{tmp_path}\\a=1/b=1/c=1/d=1/**/*",
                f"{prefix}{tmp_path}/a=1/b=1\\c=1/d=1/**/*",
                f"{prefix}{tmp_path}/a=1/b=1/c=1/d=1\\e=1/*",
            ],
            hive_partitioning=True,
        ).collect(),
        expect,
    )

    q = pl.scan_parquet("file://C:/")

    with pytest.raises(
        ComputeError, match="unsupported: non-empty hostname for 'file:' URI: 'C:'"
    ):
        q.collect()


@pytest.mark.write_disk
def test_passing_hive_schema_with_hive_partitioning_disabled_raises(
    tmp_path: Path,
) -> None:
    with pytest.raises(
        ComputeError,
        match="a hive schema was given but hive_partitioning was disabled",
    ):
        pl.scan_parquet(
            tmp_path,
            schema={"x": pl.Int64},
            hive_schema={"h": pl.String},
            hive_partitioning=False,
        ).collect()


@pytest.mark.write_disk
def test_hive_auto_enables_when_unspecified_and_hive_schema_passed(
    tmp_path: Path,
) -> None:
    tmp_path.mkdir(exist_ok=True)
    (tmp_path / "a=1").mkdir(exist_ok=True)

    pl.DataFrame({"x": 1}).write_parquet(tmp_path / "a=1/1")

    for path in [tmp_path / "a=1/1", tmp_path / "**/*"]:
        lf = pl.scan_parquet(path, hive_schema={"a": pl.UInt8})

        assert_frame_equal(
            lf.collect(),
            pl.select(
                pl.Series("x", [1]),
                pl.Series("a", [1], dtype=pl.UInt8),
            ),
        )


@pytest.mark.write_disk
def test_hive_file_as_uri_with_hive_start_idx_23830(
    tmp_path: Path,
) -> None:
    tmp_path.mkdir(exist_ok=True)
    (tmp_path / "a=1").mkdir(exist_ok=True)

    pl.DataFrame({"x": 1}).write_parquet(tmp_path / "a=1/1")

    # ensure we have a trailing "/"
    uri = tmp_path.resolve().as_posix().rstrip("/") + "/"

    lf = pl.scan_parquet(format_file_uri(uri), hive_schema={"a": pl.UInt8})

    assert_frame_equal(
        lf.collect(),
        pl.select(
            pl.Series("x", [1]),
            pl.Series("a", [1], dtype=pl.UInt8),
        ),
    )

    # https://github.com/pola-rs/polars/issues/24506
    # `file:` URI with `//hostname` component omitted
    lf = pl.scan_parquet(f"file:{uri}", hive_schema={"a": pl.UInt8})

    assert_frame_equal(
        lf.collect(),
        pl.select(
            pl.Series("x", [1]),
            pl.Series("a", [1], dtype=pl.UInt8),
        ),
    )


@pytest.mark.write_disk
@pytest.mark.parametrize("force_single_thread", [True, False])
def test_hive_parquet_prefiltered_20894_21327(
    tmp_path: Path, force_single_thread: bool
) -> None:
    n_threads = 1 if force_single_thread else pl.thread_pool_size()

    file_path = tmp_path / "date=2025-01-01/00000000.parquet"
    file_path.parent.mkdir(exist_ok=True, parents=True)

    data = pl.DataFrame(
        {
            "date": [date(2025, 1, 1), date(2025, 1, 1)],
            "value": ["1", "2"],
        }
    )

    data.write_parquet(file_path)

    import base64
    import subprocess

    # For security, and for Windows backslashes.
    scan_path_b64 = base64.b64encode(str(file_path.absolute()).encode()).decode()

    # This is the easiest way to control the threadpool size so that it is stable.
    out = subprocess.check_output(
        [
            sys.executable,
            "-c",
            f"""\
import os
os.environ["POLARS_MAX_THREADS"] = "{n_threads}"

import polars as pl
import datetime
import base64

from polars.testing import assert_frame_equal

assert pl.thread_pool_size() == {n_threads}

tmp_path = base64.b64decode("{scan_path_b64}").decode()
df = pl.scan_parquet(tmp_path, hive_partitioning=True).filter(pl.col("value") == "1").collect()
# We need the str() to trigger panic on invalid state
str(df)

assert_frame_equal(df, pl.DataFrame(
    [
        pl.Series('date', [datetime.date(2025, 1, 1)], dtype=pl.Date),
        pl.Series('value', ['1'], dtype=pl.String),
    ]
))

print("OK", end="")
""",
        ],
    )

    assert out == b"OK"


def test_hive_decode_reserved_ascii_23241(tmp_path: Path) -> None:
    partitioned_tbl_uri = (tmp_path / "partitioned_data").resolve()
    start, stop = 32, 127
    df = pl.DataFrame(
        {
            "a": list(range(start, stop)),
            "strings": [chr(i) for i in range(start, stop)],
        }
    )
    df.write_delta(partitioned_tbl_uri, delta_write_options={"partition_by": "strings"})
    out = pl.read_delta(str(partitioned_tbl_uri)).sort("a").select(pl.col("strings"))

    assert_frame_equal(df.sort(by=pl.col("a")).select(pl.col("strings")), out)


def test_hive_decode_utf8_23241(tmp_path: Path) -> None:
    df = pl.DataFrame(
        {
            "strings": [
"Türkiye And Egpyt",
                "résumé père forêt Noël",
                "😊",
                "北极熊",  # a polar bear perhaps ?!
            ],
            "a": [10, 20, 30, 40],
        }
    )
    partitioned_tbl_uri = (tmp_path / "partitioned_data").resolve()
    df.write_delta(partitioned_tbl_uri, delta_write_options={"partition_by": "strings"})
    out = pl.read_delta(str(partitioned_tbl_uri)).sort("a").select(pl.col("strings"))

    assert_frame_equal(df.sort(by=pl.col("a")).select(pl.col("strings")), out)


@pytest.mark.write_disk
def test_hive_filter_lit_true_24235(tmp_path: Path) -> None:
    df = pl.DataFrame({"p": [1, 2, 3, 4, 5], "x": [1, 1, 2, 2, 3]})
    df.lazy().sink_parquet(pl.PartitionByKey(tmp_path, by="p"), mkdir=True)

    assert_frame_equal(
        pl.scan_parquet(tmp_path).filter(True).collect(),
        df,
    )

    assert_frame_equal(
        pl.scan_parquet(tmp_path).filter(pl.lit(True)).collect(),
        df,
    )

    assert_frame_equal(
        pl.scan_parquet(tmp_path).filter(False).collect(),
        df.clear(),
    )

    assert_frame_equal(
        pl.scan_parquet(tmp_path).filter(pl.lit(False)).collect(),
        df.clear(),
    )


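# Hive predicates are applied when the scan IR is initialized: the verbose log
# should report the skipped file count exactly once, and a filter that excludes
# every partition should leave an empty `Parquet SCAN []` plan.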
def test_hive_filter_in_ir(
    tmp_path: Path, plmonkeypatch: PlMonkeyPatch, capfd: pytest.CaptureFixture[str]
) -> None:
    (tmp_path / "a=1").mkdir()
    pl.DataFrame({"x": [0, 1, 2, 3, 4]}).write_parquet(tmp_path / "a=1/data.parquet")
    (tmp_path / "a=2").mkdir()
    pl.DataFrame({"x": [5, 6, 7, 8, 9]}).write_parquet(tmp_path / "a=2/data.parquet")

    with plmonkeypatch.context() as cx:
        cx.setenv("POLARS_VERBOSE", "1")

        capfd.readouterr()

        assert_frame_equal(
            pl.scan_parquet(tmp_path).filter(pl.col("a") == 1).collect(),
            pl.DataFrame({"x": [0, 1, 2, 3, 4], "a": [1, 1, 1, 1, 1]}),
        )

        capture = capfd.readouterr().err

        # Ensure this only happens once.
        assert (
            capture.count(
                "initialize_scan_predicate: Predicate pushdown allows skipping 1 / 2 files"
            )
            == 1
        )

    plan = pl.scan_parquet(tmp_path).filter(pl.col("a") < 0).explain()
    assert plan.startswith("Parquet SCAN []")

    assert_frame_equal(
        pl.scan_parquet(tmp_path).with_row_index().filter(pl.col("a") == 2).collect(),
        pl.DataFrame(
            {"index": [5, 6, 7, 8, 9], "x": [5, 6, 7, 8, 9], "a": [2, 2, 2, 2, 2]},
            schema_overrides={"index": pl.get_index_type()},
        ),
    )

    assert_frame_equal(
        pl.scan_parquet(tmp_path).tail(1).filter(pl.col("a") == 1).collect(),
        pl.DataFrame(schema={"x": pl.Int64, "a": pl.Int64}),
    )