# GitHub Repository: pola-rs/polars
# Path: blob/main/py-polars/tests/unit/io/test_hive.py
1
from __future__ import annotations
2
3
import sys
4
import urllib.parse
5
import warnings
6
from collections import OrderedDict
7
from datetime import date, datetime
8
from functools import partial
9
from pathlib import Path
10
from typing import Any, Callable
11
12
import pyarrow.parquet as pq
13
import pytest
14
15
import polars as pl
16
from polars.exceptions import ComputeError, SchemaFieldNotFoundError
17
from polars.testing import assert_frame_equal, assert_series_equal
18
19
20
def impl_test_hive_partitioned_predicate_pushdown(
21
io_files_path: Path,
22
tmp_path: Path,
23
monkeypatch: Any,
24
) -> None:
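    # Shared body: write the IPC fixture data as a pyarrow hive-partitioned
    # dataset (partitioned by "category" and "fats_g"), then check schema
    # resolution, predicate pushdown and row counts for both the streaming and
    # in-memory engines.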
25
monkeypatch.setenv("POLARS_VERBOSE", "1")
26
df = pl.read_ipc(io_files_path / "*.ipc")
27
28
root = tmp_path / "partitioned_data"
29
30
pq.write_to_dataset(
31
df.to_arrow(),
32
root_path=root,
33
partition_cols=["category", "fats_g"],
34
)
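    # pyarrow lays the dataset out roughly as (file names and partition values
    # depend on the pyarrow version and the IPC fixture data):
    #   partitioned_data/category=<value>/fats_g=<value>/<part>.parquet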
35
q = pl.scan_parquet(root / "**/*.parquet", hive_partitioning=False)
36
# checks schema
37
assert q.collect_schema().names() == ["calories", "sugars_g"]
38
# checks materialization
39
assert q.collect().columns == ["calories", "sugars_g"]
40
41
q = pl.scan_parquet(root / "**/*.parquet", hive_partitioning=True)
42
assert q.collect_schema().names() == ["calories", "sugars_g", "category", "fats_g"]
43
44
# Partitioning changes the order
45
sort_by = ["fats_g", "category", "calories", "sugars_g"]
46
47
# The hive partitioned columns are appended,
48
# so we must ensure we assert in the proper order.
49
df = df.select(["calories", "sugars_g", "category", "fats_g"])
50
for streaming in [True, False]:
51
for pred in [
52
pl.col("category") == "vegetables",
53
pl.col("category") != "vegetables",
54
pl.col("fats_g") > 0.5,
55
(pl.col("fats_g") == 0.5) & (pl.col("category") == "vegetables"),
56
]:
57
assert_frame_equal(
58
q.filter(pred)
59
.sort(sort_by)
60
.collect(engine="streaming" if streaming else "in-memory"),
61
df.filter(pred).sort(sort_by),
62
)
63
64
# tests: 11536
65
assert q.filter(pl.col("sugars_g") == 25).collect().shape == (1, 4)
66
67
# tests: 12570
68
assert q.filter(pl.col("fats_g") == 1225.0).select("category").collect().shape == (
69
0,
70
1,
71
)
72
73
74
@pytest.mark.xdist_group("streaming")
75
@pytest.mark.write_disk
76
def test_hive_partitioned_predicate_pushdown(
77
io_files_path: Path,
78
tmp_path: Path,
79
monkeypatch: Any,
80
) -> None:
81
impl_test_hive_partitioned_predicate_pushdown(
82
io_files_path,
83
tmp_path,
84
monkeypatch,
85
)
86
87
88
@pytest.mark.xdist_group("streaming")
89
@pytest.mark.write_disk
90
def test_hive_partitioned_predicate_pushdown_single_threaded_async_17155(
91
io_files_path: Path,
92
tmp_path: Path,
93
monkeypatch: Any,
94
) -> None:
95
monkeypatch.setenv("POLARS_FORCE_ASYNC", "1")
96
monkeypatch.setenv("POLARS_PREFETCH_SIZE", "1")
97
98
impl_test_hive_partitioned_predicate_pushdown(
99
io_files_path,
100
tmp_path,
101
monkeypatch,
102
)
103
104
105
@pytest.mark.write_disk
106
@pytest.mark.may_fail_auto_streaming
107
@pytest.mark.may_fail_cloud # reason: inspects logs
108
def test_hive_partitioned_predicate_pushdown_skips_correct_number_of_files(
109
tmp_path: Path, monkeypatch: Any, capfd: Any
110
) -> None:
111
monkeypatch.setenv("POLARS_VERBOSE", "1")
112
df = pl.DataFrame({"d": pl.arange(0, 5, eager=True)}).with_columns(
113
a=pl.col("d") % 5
114
)
115
root = tmp_path / "test_int_partitions"
116
df.write_parquet(
117
root,
118
use_pyarrow=True,
119
pyarrow_options={"partition_cols": ["a"]},
120
)
121
122
q = pl.scan_parquet(root / "**/*.parquet", hive_partitioning=True)
123
assert q.filter(pl.col("a").is_in([1, 4])).collect().shape == (2, 2)
124
assert "allows skipping 3 / 5" in capfd.readouterr().err
125
126
    # Ensure common subplan elimination (CSE) can work with hive partitions.
127
q = q.filter(pl.col("a").gt(2))
128
result = q.join(q, on="a", how="left").collect(
129
optimizations=pl.QueryOptFlags(comm_subplan_elim=True)
130
)
131
expected = {
132
"a": [3, 4],
133
"d": [3, 4],
134
"d_right": [3, 4],
135
}
136
assert result.to_dict(as_series=False) == expected
137
138
139
@pytest.mark.write_disk
140
def test_hive_streaming_pushdown_is_in_22212(tmp_path: Path) -> None:
141
(
142
pl.DataFrame({"x": range(5)}).write_parquet(
143
tmp_path,
144
partition_by="x",
145
)
146
)
147
148
lf = pl.scan_parquet(tmp_path, hive_partitioning=True).filter(
149
pl.col("x").is_in([1, 4])
150
)
151
152
assert_frame_equal(
153
lf.collect(
154
engine="streaming", optimizations=pl.QueryOptFlags(predicate_pushdown=False)
155
),
156
lf.collect(
157
engine="streaming", optimizations=pl.QueryOptFlags(predicate_pushdown=True)
158
),
159
)
160
161
162
@pytest.mark.xdist_group("streaming")
163
@pytest.mark.write_disk
164
@pytest.mark.parametrize("streaming", [True, False])
165
def test_hive_partitioned_slice_pushdown(
166
io_files_path: Path, tmp_path: Path, streaming: bool
167
) -> None:
168
df = pl.read_ipc(io_files_path / "*.ipc")
169
170
root = tmp_path / "partitioned_data"
171
172
# Ignore the pyarrow legacy warning until we can write properly with new settings.
173
warnings.filterwarnings("ignore")
174
pq.write_to_dataset(
175
df.to_arrow(),
176
root_path=root,
177
partition_cols=["category", "fats_g"],
178
)
179
180
q = pl.scan_parquet(root / "**/*.parquet", hive_partitioning=True)
181
schema = q.collect_schema()
182
expect_count = pl.select(pl.lit(1, dtype=pl.UInt32).alias(x) for x in schema)
183
184
assert_frame_equal(
185
q.head(1)
186
.collect(engine="streaming" if streaming else "in-memory")
187
.select(pl.all().len()),
188
expect_count,
189
)
190
assert q.head(0).collect(
191
engine="streaming" if streaming else "in-memory"
192
).columns == [
193
"calories",
194
"sugars_g",
195
"category",
196
"fats_g",
197
]
198
199
200
@pytest.mark.xdist_group("streaming")
201
@pytest.mark.write_disk
202
def test_hive_partitioned_projection_pushdown(
203
io_files_path: Path, tmp_path: Path
204
) -> None:
205
df = pl.read_ipc(io_files_path / "*.ipc")
206
207
root = tmp_path / "partitioned_data"
208
209
# Ignore the pyarrow legacy warning until we can write properly with new settings.
210
warnings.filterwarnings("ignore")
211
pq.write_to_dataset(
212
df.to_arrow(),
213
root_path=root,
214
partition_cols=["category", "fats_g"],
215
)
216
217
q = pl.scan_parquet(root / "**/*.parquet", hive_partitioning=True)
218
columns = ["sugars_g", "category"]
219
for streaming in [True, False]:
220
assert (
221
q.select(columns)
222
.collect(engine="streaming" if streaming else "in-memory")
223
.columns
224
== columns
225
)
226
227
# test that hive partition columns are projected with the correct height when
228
# the projection contains only hive partition columns (11796)
229
for parallel in ("row_groups", "columns"):
230
q = pl.scan_parquet(
231
root / "**/*.parquet",
232
hive_partitioning=True,
233
parallel=parallel,
234
)
235
236
expected = q.collect().select("category")
237
result = q.select("category").collect()
238
239
assert_frame_equal(result, expected)
240
241
242
@pytest.mark.write_disk
243
def test_hive_partitioned_projection_skips_files(tmp_path: Path) -> None:
244
    # Ensure hive columns are parsed even when the directory value contains a "."
245
    # and that no hive column is parsed from a filename containing "=".
246
df = pl.DataFrame(
247
{"sqlver": [10012.0, 10013.0], "namespace": ["eos", "fda"], "a": [1, 2]}
248
)
249
root = tmp_path / "partitioned_data"
250
for dir_tuple, sub_df in df.partition_by(
251
["sqlver", "namespace"], include_key=False, as_dict=True
252
).items():
253
new_path = root / f"sqlver={dir_tuple[0]}" / f"namespace={dir_tuple[1]}"
254
new_path.mkdir(parents=True, exist_ok=True)
255
sub_df.write_parquet(new_path / "file=8484.parquet")
256
test_df = (
257
pl.scan_parquet(str(root) + "/**/**/*.parquet", hive_partitioning=True)
258
# don't care about column order
259
.select("sqlver", "namespace", "a", pl.exclude("sqlver", "namespace", "a"))
260
.collect()
261
)
262
assert_frame_equal(df, test_df)
263
264
265
@pytest.fixture
266
def dataset_path(tmp_path: Path) -> Path:
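    # Builds a small hive-partitioned layout:
    #   dataset/c=1/one.parquet   -> a: [1, 2]
    #   dataset/c=1/two.parquet   -> a: [3, 4]
    #   dataset/c=2/three.parquet -> a: [5, 6]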
267
tmp_path.mkdir(exist_ok=True)
268
269
# Set up Hive partitioned Parquet file
270
root = tmp_path / "dataset"
271
part1 = root / "c=1"
272
part2 = root / "c=2"
273
root.mkdir()
274
part1.mkdir()
275
part2.mkdir()
276
df1 = pl.DataFrame({"a": [1, 2], "b": [11.0, 12.0]})
277
df2 = pl.DataFrame({"a": [3, 4], "b": [13.0, 14.0]})
278
df3 = pl.DataFrame({"a": [5, 6], "b": [15.0, 16.0]})
279
df1.write_parquet(part1 / "one.parquet")
280
df2.write_parquet(part1 / "two.parquet")
281
df3.write_parquet(part2 / "three.parquet")
282
283
return root
284
285
286
@pytest.mark.write_disk
287
def test_scan_parquet_hive_schema(dataset_path: Path) -> None:
288
result = pl.scan_parquet(dataset_path / "**/*.parquet", hive_partitioning=True)
289
assert result.collect_schema() == OrderedDict(
290
{"a": pl.Int64, "b": pl.Float64, "c": pl.Int64}
291
)
292
293
result = pl.scan_parquet(
294
dataset_path / "**/*.parquet",
295
hive_partitioning=True,
296
hive_schema={"c": pl.Int32},
297
)
298
299
expected_schema = OrderedDict({"a": pl.Int64, "b": pl.Float64, "c": pl.Int32})
300
assert result.collect_schema() == expected_schema
301
assert result.collect().schema == expected_schema
302
303
304
@pytest.mark.write_disk
305
def test_read_parquet_invalid_hive_schema(dataset_path: Path) -> None:
306
with pytest.raises(
307
SchemaFieldNotFoundError,
308
match='path contains column not present in the given Hive schema: "c"',
309
):
310
pl.read_parquet(
311
dataset_path / "**/*.parquet",
312
hive_partitioning=True,
313
hive_schema={"nonexistent": pl.Int32},
314
)
315
316
317
def test_read_parquet_hive_schema_with_pyarrow() -> None:
318
with pytest.raises(
319
TypeError,
320
match="cannot use `hive_partitions` with `use_pyarrow=True`",
321
):
322
pl.read_parquet("test.parquet", hive_schema={"c": pl.Int32}, use_pyarrow=True)
323
324
325
@pytest.mark.parametrize(
326
("scan_func", "write_func"),
327
[
328
(pl.scan_parquet, pl.DataFrame.write_parquet),
329
(pl.scan_ipc, pl.DataFrame.write_ipc),
330
],
331
)
332
@pytest.mark.parametrize(
333
"glob",
334
[True, False],
335
)
336
def test_hive_partition_directory_scan(
337
tmp_path: Path,
338
scan_func: Callable[..., pl.LazyFrame],
339
write_func: Callable[[pl.DataFrame, Path], None],
340
glob: bool,
341
) -> None:
342
tmp_path.mkdir(exist_ok=True)
343
344
dfs = [
345
pl.DataFrame({'x': 5 * [1], 'a': 1, 'b': 1}),
346
pl.DataFrame({'x': 5 * [2], 'a': 1, 'b': 2}),
347
pl.DataFrame({'x': 5 * [3], 'a': 22, 'b': 1}),
348
pl.DataFrame({'x': 5 * [4], 'a': 22, 'b': 2}),
349
] # fmt: skip
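    # The loop below writes each frame (minus the hive columns) to
    # a=<a>/b=<b>/data.bin, giving four leaf directories under tmp_path.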
350
351
for df in dfs:
352
a = df.item(0, "a")
353
b = df.item(0, "b")
354
path = tmp_path / f"a={a}/b={b}/data.bin"
355
path.parent.mkdir(exist_ok=True, parents=True)
356
write_func(df.drop("a", "b"), path)
357
358
df = pl.concat(dfs)
359
hive_schema = df.lazy().select("a", "b").collect_schema()
360
361
scan = scan_func
362
363
if scan_func is pl.scan_parquet:
364
scan = partial(scan, glob=glob)
365
366
scan_with_hive_schema = partial(scan_func, hive_schema=hive_schema)
367
368
out = scan_with_hive_schema(
369
tmp_path,
370
hive_partitioning=True,
371
).collect()
372
assert_frame_equal(out, df)
373
374
out = scan(tmp_path, hive_partitioning=False).collect()
375
assert_frame_equal(out, df.drop("a", "b"))
376
377
out = scan_with_hive_schema(
378
tmp_path / "a=1",
379
hive_partitioning=True,
380
).collect()
381
assert_frame_equal(out, df.filter(a=1).drop("a"))
382
383
out = scan(tmp_path / "a=1", hive_partitioning=False).collect()
384
assert_frame_equal(out, df.filter(a=1).drop("a", "b"))
385
386
path = tmp_path / "a=1/b=1/data.bin"
387
388
out = scan_with_hive_schema(path, hive_partitioning=True).collect()
389
assert_frame_equal(out, dfs[0])
390
391
out = scan(path, hive_partitioning=False).collect()
392
assert_frame_equal(out, dfs[0].drop("a", "b"))
393
394
# Test default behavior with `hive_partitioning=None`, which should only
395
# enable hive partitioning when a single directory is passed:
396
out = scan_with_hive_schema(tmp_path).collect()
397
assert_frame_equal(out, df)
398
399
# Otherwise, hive partitioning is not enabled automatically:
400
out = scan(tmp_path / "a=1/b=1/data.bin").collect()
401
assert out.columns == ["x"]
402
403
out = scan([tmp_path / "a=1/", tmp_path / "a=22/"]).collect()
404
assert out.columns == ["x"]
405
406
out = scan([tmp_path / "a=1/", tmp_path / "a=22/b=1/data.bin"]).collect()
407
assert out.columns == ["x"]
408
409
if glob:
410
out = scan(tmp_path / "a=1/**/*.bin").collect()
411
assert out.columns == ["x"]
412
413
# Test `hive_partitioning=True`
414
out = scan_with_hive_schema(tmp_path, hive_partitioning=True).collect()
415
assert_frame_equal(out, df)
416
417
# Accept multiple directories from the same level
418
out = scan_with_hive_schema(
419
[tmp_path / "a=1", tmp_path / "a=22"], hive_partitioning=True
420
).collect()
421
assert_frame_equal(out, df.drop("a"))
422
423
with pytest.raises(
424
pl.exceptions.InvalidOperationError,
425
match="attempted to read from different directory levels with hive partitioning enabled:",
426
):
427
scan_with_hive_schema(
428
[tmp_path / "a=1", tmp_path / "a=22/b=1"], hive_partitioning=True
429
).collect()
430
431
if glob:
432
out = scan_with_hive_schema(
433
tmp_path / "**/*.bin", hive_partitioning=True
434
).collect()
435
assert_frame_equal(out, df)
436
437
# Parse hive from full path for glob patterns
438
out = scan_with_hive_schema(
439
[tmp_path / "a=1/**/*.bin", tmp_path / "a=22/**/*.bin"],
440
hive_partitioning=True,
441
).collect()
442
assert_frame_equal(out, df)
443
444
# Parse hive from full path for files
445
out = scan_with_hive_schema(
446
tmp_path / "a=1/b=1/data.bin", hive_partitioning=True
447
).collect()
448
assert_frame_equal(out, df.filter(a=1, b=1))
449
450
out = scan_with_hive_schema(
451
[tmp_path / "a=1/b=1/data.bin", tmp_path / "a=22/b=1/data.bin"],
452
hive_partitioning=True,
453
).collect()
454
assert_frame_equal(
455
out,
456
df.filter(
457
((pl.col("a") == 1) & (pl.col("b") == 1))
458
| ((pl.col("a") == 22) & (pl.col("b") == 1))
459
),
460
)
461
462
# Test `hive_partitioning=False`
463
out = scan(tmp_path, hive_partitioning=False).collect()
464
assert_frame_equal(out, df.drop("a", "b"))
465
466
if glob:
467
out = scan(tmp_path / "**/*.bin", hive_partitioning=False).collect()
468
assert_frame_equal(out, df.drop("a", "b"))
469
470
out = scan(tmp_path / "a=1/b=1/data.bin", hive_partitioning=False).collect()
471
assert_frame_equal(out, df.filter(a=1, b=1).drop("a", "b"))
472
473
474
def test_hive_partition_schema_inference(tmp_path: Path) -> None:
475
tmp_path.mkdir(exist_ok=True)
476
477
dfs = [
478
pl.DataFrame({"x": 1}),
479
pl.DataFrame({"x": 2}),
480
pl.DataFrame({"x": 3}),
481
]
482
483
paths = [
484
tmp_path / "a=1/data.bin",
485
tmp_path / "a=1.5/data.bin",
486
tmp_path / "a=polars/data.bin",
487
]
488
489
expected = [
490
pl.Series("a", [1], dtype=pl.Int64),
491
pl.Series("a", [1.0, 1.5], dtype=pl.Float64),
492
pl.Series("a", ["1", "1.5", "polars"], dtype=pl.String),
493
]
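    # Files are added one at a time, so the inferred dtype of "a" widens step by
    # step: Int64 for "1", Float64 once "1.5" appears, and String once "polars"
    # appears.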
494
495
for i in range(3):
496
paths[i].parent.mkdir(exist_ok=True, parents=True)
497
dfs[i].write_parquet(paths[i])
498
out = pl.scan_parquet(tmp_path).collect()
499
500
assert_series_equal(out["a"], expected[i])
501
502
503
@pytest.mark.write_disk
504
def test_hive_partition_force_async_17155(tmp_path: Path, monkeypatch: Any) -> None:
505
monkeypatch.setenv("POLARS_FORCE_ASYNC", "1")
506
monkeypatch.setenv("POLARS_PREFETCH_SIZE", "1")
507
508
dfs = [
509
pl.DataFrame({"x": 1}),
510
pl.DataFrame({"x": 2}),
511
pl.DataFrame({"x": 3}),
512
]
513
514
paths = [
515
tmp_path / "a=1/b=1/data.bin",
516
tmp_path / "a=2/b=2/data.bin",
517
tmp_path / "a=3/b=3/data.bin",
518
]
519
520
for i in range(3):
521
paths[i].parent.mkdir(exist_ok=True, parents=True)
522
dfs[i].write_parquet(paths[i])
523
524
lf = pl.scan_parquet(tmp_path)
525
526
assert_frame_equal(
527
lf.collect(), pl.DataFrame({k: [1, 2, 3] for k in ["x", "a", "b"]})
528
)
529
530
531
@pytest.mark.parametrize(
532
("scan_func", "write_func"),
533
[
534
(partial(pl.scan_parquet, parallel="row_groups"), pl.DataFrame.write_parquet),
535
(partial(pl.scan_parquet, parallel="columns"), pl.DataFrame.write_parquet),
536
(partial(pl.scan_parquet, parallel="prefiltered"), pl.DataFrame.write_parquet),
537
(
538
lambda *a, **kw: pl.scan_parquet(*a, parallel="prefiltered", **kw).filter(
539
pl.col("b") == pl.col("b")
540
),
541
pl.DataFrame.write_parquet,
542
),
543
(pl.scan_ipc, pl.DataFrame.write_ipc),
544
],
545
)
546
@pytest.mark.write_disk
547
@pytest.mark.slow
548
@pytest.mark.parametrize("projection_pushdown", [True, False])
549
def test_hive_partition_columns_contained_in_file(
550
tmp_path: Path,
551
scan_func: Callable[[Any], pl.LazyFrame],
552
write_func: Callable[[pl.DataFrame, Path], None],
553
projection_pushdown: bool,
554
) -> None:
555
path = tmp_path / "a=1/b=2/data.bin"
556
path.parent.mkdir(exist_ok=True, parents=True)
557
df = pl.DataFrame(
558
{"x": 1, "a": 1, "b": 2, "y": 1},
559
schema={"x": pl.Int32, "a": pl.Int8, "b": pl.Int16, "y": pl.Int32},
560
)
561
write_func(df, path)
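    # Here the hive columns "a" and "b" are also physically present in the file;
    # the scan must expose a single column per name, using the file's dtype by
    # default and the dtype from an explicit hive_schema when one is given.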
562
563
def assert_with_projections(
564
lf: pl.LazyFrame, df: pl.DataFrame, *, row_index: str | None = None
565
) -> None:
566
row_index: list[str] = [row_index] if row_index is not None else [] # type: ignore[no-redef]
567
568
from itertools import permutations
569
570
cols = ["a", "b", "x", "y", *row_index] # type: ignore[misc]
571
572
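        # Exercise every projection that is a permutation of a prefix of `cols`,
        # so neither column order nor projection pushdown should change results.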
for projection in (
573
x for i in range(len(cols)) for x in permutations(cols[: 1 + i])
574
):
575
assert_frame_equal(
576
lf.select(projection).collect(
577
optimizations=pl.QueryOptFlags(
578
projection_pushdown=projection_pushdown
579
)
580
),
581
df.select(projection),
582
)
583
584
lf = scan_func(path, hive_partitioning=True) # type: ignore[call-arg]
585
rhs = df
586
assert_frame_equal(
587
lf.collect(
588
optimizations=pl.QueryOptFlags(projection_pushdown=projection_pushdown)
589
),
590
rhs,
591
)
592
assert_with_projections(lf, rhs)
593
594
lf = scan_func( # type: ignore[call-arg]
595
path,
596
hive_schema={"a": pl.String, "b": pl.String},
597
hive_partitioning=True,
598
)
599
rhs = df.with_columns(pl.col("a", "b").cast(pl.String))
600
assert_frame_equal(
601
lf.collect(
602
optimizations=pl.QueryOptFlags(projection_pushdown=projection_pushdown)
603
),
604
rhs,
605
)
606
assert_with_projections(lf, rhs)
607
608
# partial cols in file
609
partial_path = tmp_path / "a=1/b=2/partial_data.bin"
610
df = pl.DataFrame(
611
{"x": 1, "b": 2, "y": 1},
612
schema={"x": pl.Int32, "b": pl.Int16, "y": pl.Int32},
613
)
614
write_func(df, partial_path)
615
616
rhs = rhs.select(
617
pl.col("x").cast(pl.Int32),
618
pl.col("b").cast(pl.Int16),
619
pl.col("y").cast(pl.Int32),
620
pl.col("a").cast(pl.Int64),
621
)
622
623
lf = scan_func(partial_path, hive_partitioning=True) # type: ignore[call-arg]
624
assert_frame_equal(
625
lf.collect(
626
optimizations=pl.QueryOptFlags(projection_pushdown=projection_pushdown)
627
),
628
rhs,
629
)
630
assert_with_projections(lf, rhs)
631
632
assert_frame_equal(
633
lf.with_row_index().collect(
634
optimizations=pl.QueryOptFlags(projection_pushdown=projection_pushdown)
635
),
636
rhs.with_row_index(),
637
)
638
assert_with_projections(
639
lf.with_row_index(), rhs.with_row_index(), row_index="index"
640
)
641
642
assert_frame_equal(
643
lf.with_row_index()
644
.select(pl.exclude("index"), "index")
645
.collect(
646
optimizations=pl.QueryOptFlags(projection_pushdown=projection_pushdown)
647
),
648
rhs.with_row_index().select(pl.exclude("index"), "index"),
649
)
650
assert_with_projections(
651
lf.with_row_index().select(pl.exclude("index"), "index"),
652
rhs.with_row_index().select(pl.exclude("index"), "index"),
653
row_index="index",
654
)
655
656
lf = scan_func( # type: ignore[call-arg]
657
partial_path,
658
hive_schema={"a": pl.String, "b": pl.String},
659
hive_partitioning=True,
660
)
661
rhs = rhs.select(
662
pl.col("x").cast(pl.Int32),
663
pl.col("b").cast(pl.String),
664
pl.col("y").cast(pl.Int32),
665
pl.col("a").cast(pl.String),
666
)
667
assert_frame_equal(
668
lf.collect(
669
optimizations=pl.QueryOptFlags(projection_pushdown=projection_pushdown)
670
),
671
rhs,
672
)
673
assert_with_projections(lf, rhs)
674
675
676
@pytest.mark.write_disk
677
def test_hive_partition_dates(tmp_path: Path) -> None:
678
df = pl.DataFrame(
679
{
680
"date1": [
681
datetime(2024, 1, 1),
682
datetime(2024, 2, 1),
683
datetime(2024, 3, 1),
684
None,
685
],
686
"date2": [
687
datetime(2023, 1, 1),
688
datetime(2023, 2, 1),
689
None,
690
datetime(2023, 3, 1),
691
],
692
"x": [1, 2, 3, 4],
693
},
694
schema={"date1": pl.Date, "date2": pl.Datetime, "x": pl.Int32},
695
)
696
697
root = tmp_path / "pyarrow"
698
pq.write_to_dataset(
699
df.to_arrow(),
700
root_path=root,
701
partition_cols=["date1", "date2"],
702
)
703
704
lf = pl.scan_parquet(
705
root, hive_schema=df.clear().select("date1", "date2").collect_schema()
706
)
707
assert_frame_equal(lf.collect(), df.select("x", "date1", "date2"))
708
709
lf = pl.scan_parquet(root)
710
assert_frame_equal(lf.collect(), df.select("x", "date1", "date2"))
711
712
lf = pl.scan_parquet(root, try_parse_hive_dates=False)
713
assert_frame_equal(
714
lf.collect(),
715
df.select("x", "date1", "date2").with_columns(
716
pl.col("date1", "date2").cast(pl.String)
717
),
718
)
719
720
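    # Also cover hive columns that are physically present in the files, writing
    # the datetime partition value both percent-escaped and raw in the directory
    # name. The raw form is skipped on Windows, presumably because ":" cannot
    # appear in Windows paths.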
for perc_escape in [True, False] if sys.platform != "win32" else [True]:
721
root = tmp_path / f"includes_hive_cols_in_file_{perc_escape}"
722
for (date1, date2), part_df in df.group_by(
723
pl.col("date1").cast(pl.String).fill_null("__HIVE_DEFAULT_PARTITION__"),
724
pl.col("date2").cast(pl.String).fill_null("__HIVE_DEFAULT_PARTITION__"),
725
):
726
if perc_escape:
727
date2 = urllib.parse.quote(date2)
728
729
path = root / f"date1={date1}/date2={date2}/data.bin"
730
path.parent.mkdir(exist_ok=True, parents=True)
731
part_df.write_parquet(path)
732
733
# The schema for the hive columns is included in the file, so it should
734
# just work
735
lf = pl.scan_parquet(root)
736
assert_frame_equal(lf.collect(), df)
737
738
lf = pl.scan_parquet(root, try_parse_hive_dates=False)
739
assert_frame_equal(
740
lf.collect(),
741
df.with_columns(pl.col("date1", "date2").cast(pl.String)),
742
)
743
744
745
@pytest.mark.write_disk
746
def test_hive_partition_filter_null_23005(tmp_path: Path) -> None:
747
root = tmp_path
748
749
df = pl.DataFrame(
750
{
751
"date1": [
752
datetime(2024, 1, 1),
753
datetime(2024, 2, 1),
754
datetime(2024, 3, 1),
755
None,
756
],
757
"date2": [
758
datetime(2023, 1, 1),
759
datetime(2023, 2, 1),
760
None,
761
datetime(2023, 3, 1),
762
],
763
"x": [1, 2, 3, 4],
764
},
765
schema={"date1": pl.Date, "date2": pl.Datetime, "x": pl.Int32},
766
)
767
768
df.write_parquet(root, partition_by=["date1", "date2"])
769
770
q = pl.scan_parquet(root, include_file_paths="path")
771
772
full = q.collect()
773
774
assert (
775
full.select(
776
(
777
pl.any_horizontal(pl.col("date1", "date2").is_null())
778
& pl.col("path").str.contains("__HIVE_DEFAULT_PARTITION__")
779
).sum()
780
).item()
781
== 2
782
)
783
784
lf = pl.scan_parquet(root).filter(pl.col("date1") == datetime(2024, 1, 1))
785
assert_frame_equal(lf.collect(), df.head(1))
786
787
788
@pytest.mark.parametrize(
789
("scan_func", "write_func"),
790
[
791
(pl.scan_parquet, pl.DataFrame.write_parquet),
792
(pl.scan_ipc, pl.DataFrame.write_ipc),
793
],
794
)
795
@pytest.mark.write_disk
796
def test_projection_only_hive_parts_gives_correct_number_of_rows(
797
tmp_path: Path,
798
scan_func: Callable[[Any], pl.LazyFrame],
799
write_func: Callable[[pl.DataFrame, Path], None],
800
) -> None:
801
# Check the number of rows projected when projecting only hive parts, which
802
# should be the same as the number of rows in the file.
803
path = tmp_path / "a=3/data.bin"
804
path.parent.mkdir(exist_ok=True, parents=True)
805
806
write_func(pl.DataFrame({"x": [1, 1, 1]}), path)
807
808
assert_frame_equal(
809
scan_func(path, hive_partitioning=True).select("a").collect(), # type: ignore[call-arg]
810
pl.DataFrame({"a": [3, 3, 3]}),
811
)
812
813
814
@pytest.mark.parametrize(
815
"df",
816
[
817
pl.select(
818
pl.Series("a", [1, 2, 3, 4], dtype=pl.Int8),
819
pl.Series("b", [1, 2, 3, 4], dtype=pl.Int8),
820
pl.Series("x", [1, 2, 3, 4]),
821
),
822
pl.select(
823
pl.Series(
824
"a",
825
[1.2981275, 2.385974035, 3.1231892749185718397510, 4.129387128949156],
826
dtype=pl.Float64,
827
),
828
pl.Series("b", ["a", "b", " / c = : ", "d"]),
829
pl.Series("x", [1, 2, 3, 4]),
830
),
831
],
832
)
833
@pytest.mark.write_disk
834
def test_hive_write(tmp_path: Path, df: pl.DataFrame) -> None:
835
root = tmp_path
836
df.write_parquet(root, partition_by=["a", "b"])
837
838
lf = pl.scan_parquet(root)
839
assert_frame_equal(lf.collect(), df)
840
841
lf = pl.scan_parquet(root, hive_schema={"a": pl.String, "b": pl.String})
842
assert_frame_equal(lf.collect(), df.with_columns(pl.col("a", "b").cast(pl.String)))
843
844
845
@pytest.mark.slow
846
@pytest.mark.write_disk
847
@pytest.mark.xfail
848
def test_hive_write_multiple_files(tmp_path: Path) -> None:
849
chunk_size = 262_144
850
n_rows = 100_000
851
df = pl.select(a=pl.repeat(0, n_rows), b=pl.int_range(0, n_rows))
852
853
n_files = int(df.estimated_size() / chunk_size)
854
855
assert n_files > 1, "increase df size or decrease file size"
856
857
root = tmp_path
858
df.write_parquet(root, partition_by="a", partition_chunk_size_bytes=chunk_size)
859
860
assert sum(1 for _ in (root / "a=0").iterdir()) == n_files
861
assert_frame_equal(pl.scan_parquet(root).collect(), df)
862
863
864
@pytest.mark.write_disk
865
def test_hive_write_dates(tmp_path: Path) -> None:
866
df = pl.DataFrame(
867
{
868
"date1": [
869
datetime(2024, 1, 1),
870
datetime(2024, 2, 1),
871
datetime(2024, 3, 1),
872
None,
873
],
874
"date2": [
875
datetime(2023, 1, 1),
876
datetime(2023, 2, 1),
877
None,
878
datetime(2023, 3, 1, 1, 1, 1, 1),
879
],
880
"x": [1, 2, 3, 4],
881
},
882
schema={"date1": pl.Date, "date2": pl.Datetime, "x": pl.Int32},
883
)
884
885
root = tmp_path
886
df.write_parquet(root, partition_by=["date1", "date2"])
887
888
lf = pl.scan_parquet(root)
889
assert_frame_equal(lf.collect(), df)
890
891
lf = pl.scan_parquet(root, try_parse_hive_dates=False)
892
assert_frame_equal(
893
lf.collect(),
894
df.with_columns(pl.col("date1", "date2").cast(pl.String)),
895
)
896
897
898
@pytest.mark.write_disk
899
@pytest.mark.may_fail_auto_streaming
900
@pytest.mark.may_fail_cloud # reason: inspects logs
901
def test_hive_predicate_dates_14712(
902
tmp_path: Path, monkeypatch: Any, capfd: Any
903
) -> None:
904
monkeypatch.setenv("POLARS_VERBOSE", "1")
905
pl.DataFrame({"a": [datetime(2024, 1, 1)]}).write_parquet(
906
tmp_path, partition_by="a"
907
)
908
pl.scan_parquet(tmp_path).filter(pl.col("a") != datetime(2024, 1, 1)).collect()
909
assert "allows skipping 1 / 1" in capfd.readouterr().err
910
911
912
@pytest.mark.skipif(sys.platform != "win32", reason="Test is only for Windows paths")
913
@pytest.mark.write_disk
914
def test_hive_windows_splits_on_forward_slashes(tmp_path: Path) -> None:
915
# Note: This needs to be an absolute path.
916
tmp_path = tmp_path.resolve()
917
path = f"{tmp_path}/a=1/b=1/c=1/d=1/e=1"
918
Path(path).mkdir(exist_ok=True, parents=True)
919
920
df = pl.DataFrame({"x": "x"})
921
df.write_parquet(f"{path}/data.parquet")
922
923
expect = pl.DataFrame(
924
[
925
s.new_from_index(0, 5)
926
for s in pl.DataFrame(
927
{
928
"x": "x",
929
"a": 1,
930
"b": 1,
931
"c": 1,
932
"d": 1,
933
"e": 1,
934
}
935
)
936
]
937
)
938
939
assert_frame_equal(
940
pl.scan_parquet(
941
[
942
f"{tmp_path}/a=1/b=1/c=1/d=1/e=1/data.parquet",
943
f"{tmp_path}\\a=1\\b=1\\c=1\\d=1\\e=1\\data.parquet",
944
f"{tmp_path}\\a=1/b=1/c=1/d=1/**/*",
945
f"{tmp_path}/a=1/b=1\\c=1/d=1/**/*",
946
f"{tmp_path}/a=1/b=1/c=1/d=1\\e=1/*",
947
],
948
hive_partitioning=True,
949
).collect(),
950
expect,
951
)
952
953
954
@pytest.mark.write_disk
955
def test_passing_hive_schema_with_hive_partitioning_disabled_raises(
956
tmp_path: Path,
957
) -> None:
958
with pytest.raises(
959
ComputeError,
960
match="a hive schema was given but hive_partitioning was disabled",
961
):
962
pl.scan_parquet(
963
tmp_path,
964
schema={"x": pl.Int64},
965
hive_schema={"h": pl.String},
966
hive_partitioning=False,
967
).collect()
968
969
970
@pytest.mark.write_disk
971
def test_hive_auto_enables_when_unspecified_and_hive_schema_passed(
972
tmp_path: Path,
973
) -> None:
974
tmp_path.mkdir(exist_ok=True)
975
(tmp_path / "a=1").mkdir(exist_ok=True)
976
977
pl.DataFrame({"x": 1}).write_parquet(tmp_path / "a=1/1")
978
979
for path in [tmp_path / "a=1/1", tmp_path / "**/*"]:
980
lf = pl.scan_parquet(path, hive_schema={"a": pl.UInt8})
981
982
assert_frame_equal(
983
lf.collect(),
984
pl.select(
985
pl.Series("x", [1]),
986
pl.Series("a", [1], dtype=pl.UInt8),
987
),
988
)
989
990
991
@pytest.mark.write_disk
992
def test_hive_file_as_uri_with_hive_start_idx_23830(
993
tmp_path: Path,
994
) -> None:
995
tmp_path.mkdir(exist_ok=True)
996
(tmp_path / "a=1").mkdir(exist_ok=True)
997
998
pl.DataFrame({"x": 1}).write_parquet(tmp_path / "a=1/1")
999
1000
# ensure we have a trailing "/"
1001
uri = tmp_path.resolve().as_posix().rstrip("/") + "/"
1002
uri = "file://" + uri
1003
1004
lf = pl.scan_parquet(uri, hive_schema={"a": pl.UInt8})
1005
1006
assert_frame_equal(
1007
lf.collect(),
1008
pl.select(
1009
pl.Series("x", [1]),
1010
pl.Series("a", [1], dtype=pl.UInt8),
1011
),
1012
)
1013
1014
1015
@pytest.mark.write_disk
1016
@pytest.mark.parametrize("force_single_thread", [True, False])
1017
def test_hive_parquet_prefiltered_20894_21327(
1018
tmp_path: Path, force_single_thread: bool
1019
) -> None:
1020
n_threads = 1 if force_single_thread else pl.thread_pool_size()
1021
1022
file_path = tmp_path / "date=2025-01-01/00000000.parquet"
1023
file_path.parent.mkdir(exist_ok=True, parents=True)
1024
1025
data = pl.DataFrame(
1026
{
1027
"date": [date(2025, 1, 1), date(2025, 1, 1)],
1028
"value": ["1", "2"],
1029
}
1030
)
1031
1032
data.write_parquet(file_path)
1033
1034
import base64
1035
import subprocess
1036
1037
    # Base64-encode the path so it can be embedded safely in the inline script
    # below (this also sidesteps Windows backslash escaping).
1038
scan_path_b64 = base64.b64encode(str(file_path.absolute()).encode()).decode()
1039
1040
    # This is the easiest way to control the thread pool size so that it is stable.
1041
out = subprocess.check_output(
1042
[
1043
sys.executable,
1044
"-c",
1045
f"""\
1046
import os
1047
os.environ["POLARS_MAX_THREADS"] = "{n_threads}"
1048
1049
import polars as pl
1050
import datetime
1051
import base64
1052
1053
from polars.testing import assert_frame_equal
1054
1055
assert pl.thread_pool_size() == {n_threads}
1056
1057
tmp_path = base64.b64decode("{scan_path_b64}").decode()
1058
df = pl.scan_parquet(tmp_path, hive_partitioning=True).filter(pl.col("value") == "1").collect()
1059
# We need the str() to trigger panic on invalid state
1060
str(df)
1061
1062
assert_frame_equal(df, pl.DataFrame(
1063
[
1064
pl.Series('date', [datetime.date(2025, 1, 1)], dtype=pl.Date),
1065
pl.Series('value', ['1'], dtype=pl.String),
1066
]
1067
))
1068
1069
print("OK", end="")
1070
""",
1071
],
1072
)
1073
1074
assert out == b"OK"
1075
1076
1077
def test_hive_decode_reserved_ascii_23241(tmp_path: Path) -> None:
1078
partitioned_tbl_uri = (tmp_path / "partitioned_data").resolve()
1079
start, stop = 32, 127
1080
df = pl.DataFrame(
1081
{
1082
"a": list(range(start, stop)),
1083
"strings": [chr(i) for i in range(start, stop)],
1084
}
1085
)
1086
df.write_delta(partitioned_tbl_uri, delta_write_options={"partition_by": "strings"})
1087
out = pl.read_delta(str(partitioned_tbl_uri)).sort("a").select(pl.col("strings"))
1088
1089
assert_frame_equal(df.sort(by=pl.col("a")).select(pl.col("strings")), out)
1090
1091
1092
def test_hive_decode_utf8_23241(tmp_path: Path) -> None:
1093
df = pl.DataFrame(
1094
{
1095
"strings": [
1096
"Türkiye And Egpyt",
1097
"résumé père forêt Noël",
1098
"😊",
1099
"北极熊", # a polar bear perhaps ?!
1100
],
1101
"a": [10, 20, 30, 40],
1102
}
1103
)
1104
partitioned_tbl_uri = (tmp_path / "partitioned_data").resolve()
1105
df.write_delta(partitioned_tbl_uri, delta_write_options={"partition_by": "strings"})
1106
out = pl.read_delta(str(partitioned_tbl_uri)).sort("a").select(pl.col("strings"))
1107
1108
assert_frame_equal(df.sort(by=pl.col("a")).select(pl.col("strings")), out)
1109
1110
1111
@pytest.mark.write_disk
1112
def test_hive_filter_lit_true_24235(tmp_path: Path) -> None:
1113
df = pl.DataFrame({"p": [1, 2, 3, 4, 5], "x": [1, 1, 2, 2, 3]})
1114
df.lazy().sink_parquet(pl.PartitionByKey(tmp_path, by="p"), mkdir=True)
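    # The scans below check that literal True/False predicates survive hive
    # partition pruning: filter(True) must keep every row and filter(False)
    # must return an empty frame with the original schema.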
1115
1116
assert_frame_equal(
1117
pl.scan_parquet(tmp_path).filter(True).collect(),
1118
df,
1119
)
1120
1121
assert_frame_equal(
1122
pl.scan_parquet(tmp_path).filter(pl.lit(True)).collect(),
1123
df,
1124
)
1125
1126
assert_frame_equal(
1127
pl.scan_parquet(tmp_path).filter(False).collect(),
1128
df.clear(),
1129
)
1130
1131
assert_frame_equal(
1132
pl.scan_parquet(tmp_path).filter(pl.lit(False)).collect(),
1133
df.clear(),
1134
)
1135
1136