GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/test_multiscan.py
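"""Tests for scanning multiple sources at once ("multiscan") with polars.

Covers the IPC, Parquet, CSV and NDJSON readers: file-path columns, projections,
row indices, hive partitioning, slicing, schema mismatches, null upcasting and
hidden-file handling.
"""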
from __future__ import annotations

import io
import re
import sys
from functools import partial
from typing import IO, TYPE_CHECKING, Any

import pyarrow.parquet as pq
import pytest
from hypothesis import given
from hypothesis import strategies as st

import polars as pl
from polars.meta.index_type import get_index_type
from polars.testing import assert_frame_equal
from tests.unit.io.conftest import normalize_path_separator_pl

if TYPE_CHECKING:
    from collections.abc import Callable
    from pathlib import Path

    from tests.conftest import PlMonkeyPatch

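# Pairs of (lazy scan function, eager write method), one per supported file
# format; most tests below are parametrized over these so that every format
# goes through the same multiscan checks.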
SCAN_AND_WRITE_FUNCS = [
    (pl.scan_ipc, pl.DataFrame.write_ipc),
    (pl.scan_parquet, pl.DataFrame.write_parquet),
    (pl.scan_csv, pl.DataFrame.write_csv),
    (pl.scan_ndjson, pl.DataFrame.write_ndjson),
]

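# `include_file_paths` adds a column recording the source file of each row;
# the separator is normalized so the expectation also holds on Windows.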
@pytest.mark.write_disk
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_include_file_paths(tmp_path: Path, scan: Any, write: Any) -> None:
    a_path = tmp_path / "a"
    b_path = tmp_path / "b"

    write(pl.DataFrame({"a": [5, 10]}), a_path)
    write(pl.DataFrame({"a": [1996]}), b_path)

    out = scan([a_path, b_path], include_file_paths="f")

    assert_frame_equal(
        out.collect(),
        pl.DataFrame(
            {
                "a": [5, 10, 1996],
                "f": [str(a_path), str(a_path), str(b_path)],
            }
        ).with_columns(normalize_path_separator_pl(pl.col("f"))),
    )

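# Exhaustively combines the optional output columns (inserted missing columns,
# row index, file paths, hive columns) and checks that projecting before
# collection matches projecting afterwards, i.e. that projection pushdown does
# not change results.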
@pytest.mark.parametrize(
    ("scan", "write", "ext", "supports_missing_columns", "supports_hive_partitioning"),
    [
        (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc", False, True),
        (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet", True, True),
        (pl.scan_csv, pl.DataFrame.write_csv, "csv", False, False),
        (pl.scan_ndjson, pl.DataFrame.write_ndjson, "jsonl", False, False),
    ],
)
@pytest.mark.parametrize("missing_column", [False, True])
@pytest.mark.parametrize("row_index", [False, True])
@pytest.mark.parametrize("include_file_paths", [False, True])
@pytest.mark.parametrize("hive", [False, True])
@pytest.mark.parametrize("col", [False, True])
@pytest.mark.write_disk
def test_multiscan_projection(
    tmp_path: Path,
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, Path], Any],
    ext: str,
    supports_missing_columns: bool,
    supports_hive_partitioning: bool,
    missing_column: bool,
    row_index: bool,
    include_file_paths: bool,
    hive: bool,
    col: bool,
) -> None:
    a = pl.DataFrame({"col": [5, 10, 1996]})
    b = pl.DataFrame({"col": [13, 37]})

    if missing_column and supports_missing_columns:
        a = a.with_columns(missing=pl.Series([420, 2000, 9]))

    a_path: Path
    b_path: Path
    multiscan_path: Path

    if hive and supports_hive_partitioning:
        (tmp_path / "hive_col=0").mkdir()
        a_path = tmp_path / "hive_col=0" / f"a.{ext}"
        (tmp_path / "hive_col=1").mkdir()
        b_path = tmp_path / "hive_col=1" / f"b.{ext}"

        multiscan_path = tmp_path

    else:
        a_path = tmp_path / f"a.{ext}"
        b_path = tmp_path / f"b.{ext}"

        multiscan_path = tmp_path / f"*.{ext}"

    write(a, a_path)
    write(b, b_path)

    base_projection = []
    if missing_column and supports_missing_columns:
        base_projection += ["missing"]
    if row_index:
        base_projection += ["row_index"]
    if include_file_paths:
        base_projection += ["file_path"]
    if hive and supports_hive_partitioning:
        base_projection += ["hive_col"]
    if col:
        base_projection += ["col"]

    ifp = "file_path" if include_file_paths else None
    ri = "row_index" if row_index else None

    args = {
        "missing_columns": "insert" if missing_column else "raise",
        "include_file_paths": ifp,
        "row_index_name": ri,
        "hive_partitioning": hive,
    }

    if not supports_missing_columns:
        del args["missing_columns"]
    if not supports_hive_partitioning:
        del args["hive_partitioning"]

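    # The projection (in both orders, and with each column dropped in turn
    # below) must give identical results whether applied before or after
    # collection.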
    for projection in [
        base_projection,
        base_projection[::-1],
    ]:
        assert_frame_equal(
            scan(multiscan_path, **args).collect(engine="streaming").select(projection),
            scan(multiscan_path, **args).select(projection).collect(engine="streaming"),
        )

    for remove in range(len(base_projection)):
        new_projection = base_projection.copy()
        new_projection.pop(remove)

        for projection in [
            new_projection,
            new_projection[::-1],
        ]:
            assert_frame_equal(
                scan(multiscan_path, **args)
                .collect(engine="streaming")
                .select(projection),
                scan(multiscan_path, **args)
                .select(projection)
                .collect(engine="streaming"),
            )

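# Predicates on hive partition columns may prune entire files; the result must
# match filtering the fully-collected data, with and without a row index.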
@pytest.mark.parametrize(
    ("scan", "write", "ext"),
    [
        (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc"),
        (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet"),
    ],
)
@pytest.mark.write_disk
def test_multiscan_hive_predicate(
    tmp_path: Path,
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, Path], Any],
    ext: str,
) -> None:
    a = pl.DataFrame({"col": [5, 10, 1996]})
    b = pl.DataFrame({"col": [13, 37]})
    c = pl.DataFrame({"col": [3, 5, 2024]})

    (tmp_path / "hive_col=0").mkdir()
    a_path = tmp_path / "hive_col=0" / f"0.{ext}"
    (tmp_path / "hive_col=1").mkdir()
    b_path = tmp_path / "hive_col=1" / f"0.{ext}"
    (tmp_path / "hive_col=2").mkdir()
    c_path = tmp_path / "hive_col=2" / f"0.{ext}"

    multiscan_path = tmp_path

    write(a, a_path)
    write(b, b_path)
    write(c, c_path)

    full = scan(multiscan_path).collect(engine="streaming")
    full_ri = full.with_row_index("ri", 42)

    last_pred = None
    try:
        for pred in [
            pl.col.hive_col == 0,
            pl.col.hive_col == 1,
            pl.col.hive_col == 2,
            pl.col.hive_col < 2,
            pl.col.hive_col > 0,
            pl.col.hive_col != 1,
            pl.col.hive_col != 3,
            pl.col.col == 13,
            pl.col.col != 13,
            (pl.col.col != 13) & (pl.col.hive_col == 1),
            (pl.col.col != 13) & (pl.col.hive_col != 1),
        ]:
            last_pred = pred
            assert_frame_equal(
                full.filter(pred),
                scan(multiscan_path).filter(pred).collect(engine="streaming"),
            )

            assert_frame_equal(
                full_ri.filter(pred),
                scan(multiscan_path)
                .with_row_index("ri", 42)
                .filter(pred)
                .collect(engine="streaming"),
            )
    except Exception:
        # Print the failing predicate before re-raising, to aid debugging.
        print(last_pred)
        raise

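# The row index must run continuously across file boundaries and respect
# offsets, slices, filters and repeated `with_row_index` calls.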
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
@pytest.mark.write_disk
def test_multiscan_row_index(
    tmp_path: Path,
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, Path], Any],
) -> None:
    a = pl.DataFrame({"col": [5, 10, 1996]})
    b = pl.DataFrame({"col": [42]})
    c = pl.DataFrame({"col": [13, 37]})

    write(a, tmp_path / "a")
    write(b, tmp_path / "b")
    write(c, tmp_path / "c")

    col = pl.concat([a, b, c]).to_series()
    g = tmp_path / "*"

    assert_frame_equal(
        scan(g, row_index_name="ri").collect(),
        pl.DataFrame(
            [
                pl.Series("ri", range(6), get_index_type()),
                col,
            ]
        ),
    )

    start = 42
    assert_frame_equal(
        scan(g, row_index_name="ri", row_index_offset=start).collect(),
        pl.DataFrame(
            [
                pl.Series("ri", range(start, start + 6), get_index_type()),
                col,
            ]
        ),
    )

    start = 42
    assert_frame_equal(
        scan(g, row_index_name="ri", row_index_offset=start).slice(3, 3).collect(),
        pl.DataFrame(
            [
                pl.Series("ri", range(start + 3, start + 6), get_index_type()),
                col.slice(3, 3),
            ]
        ),
    )

    start = 42
    assert_frame_equal(
        scan(g, row_index_name="ri", row_index_offset=start)
        .filter(pl.col("col") < 15)
        .collect(),
        pl.DataFrame(
            [
                pl.Series("ri", [start + 0, start + 1, start + 4], get_index_type()),
                pl.Series("col", [5, 10, 13]),
            ]
        ),
    )

    with pytest.raises(
        pl.exceptions.DuplicateError, match="duplicate column name index"
    ):
        scan(g).with_row_index().with_row_index().collect()

    assert_frame_equal(
        scan(g)
        .with_row_index()
        .with_row_index("index_1", offset=1)
        .with_row_index("index_2", offset=2)
        .collect(),
        pl.DataFrame(
            [
                pl.Series("index_2", [2, 3, 4, 5, 6, 7], get_index_type()),
                pl.Series("index_1", [1, 2, 3, 4, 5, 6], get_index_type()),
                pl.Series("index", [0, 1, 2, 3, 4, 5], get_index_type()),
                col,
            ]
        ),
    )

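# Conflicting dtypes for the same column across files must raise instead of
# silently casting (NDJSON raises later, while parsing values).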
@pytest.mark.parametrize(
    ("scan", "write", "ext"),
    [
        (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc"),
        (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet"),
        pytest.param(
            pl.scan_csv,
            pl.DataFrame.write_csv,
            "csv",
            marks=pytest.mark.xfail(
                reason="See https://github.com/pola-rs/polars/issues/21211"
            ),
        ),
        (pl.scan_ndjson, pl.DataFrame.write_ndjson, "jsonl"),
    ],
)
@pytest.mark.write_disk
def test_schema_mismatch_type_mismatch(
    tmp_path: Path,
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, Path], Any],
    ext: str,
) -> None:
    a = pl.DataFrame({"xyz_col": [5, 10, 1996]})
    b = pl.DataFrame({"xyz_col": ["a", "b", "c"]})

    a_path = tmp_path / f"a.{ext}"
    b_path = tmp_path / f"b.{ext}"

    multiscan_path = tmp_path / f"*.{ext}"

    write(a, a_path)
    write(b, b_path)

    q = scan(multiscan_path)

    # NDJSON will just parse according to `projected_schema`
    cx = (
        pytest.raises(
            pl.exceptions.ComputeError,
            match=re.escape("cannot parse 'a' (string) as Int64"),
        )
        if scan is pl.scan_ndjson
        else pytest.raises(
            pl.exceptions.SchemaError,  # type: ignore[arg-type]
            match=(
                "data type mismatch for column xyz_col: "
                "incoming: String != target: Int64"
            ),
        )
    )

    with cx:
        q.collect(engine="streaming")

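# The same columns appearing in a different order across files should raise a
# SchemaError.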
@pytest.mark.parametrize(
    ("scan", "write", "ext"),
    [
        # (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet"), # TODO: _
        # (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc"), # TODO: _
        pytest.param(
            pl.scan_csv,
            pl.DataFrame.write_csv,
            "csv",
            marks=pytest.mark.xfail(
                reason="See https://github.com/pola-rs/polars/issues/21211"
            ),
        ),
        # (pl.scan_ndjson, pl.DataFrame.write_ndjson, "jsonl"), # TODO: _
    ],
)
@pytest.mark.write_disk
def test_schema_mismatch_order_mismatch(
    tmp_path: Path,
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, Path], Any],
    ext: str,
) -> None:
    a = pl.DataFrame({"x": [5, 10, 1996], "y": ["a", "b", "c"]})
    b = pl.DataFrame({"y": ["x", "y"], "x": [1, 2]})

    a_path = tmp_path / f"a.{ext}"
    b_path = tmp_path / f"b.{ext}"

    multiscan_path = tmp_path / f"*.{ext}"

    write(a, a_path)
    write(b, b_path)

    q = scan(multiscan_path)

    with pytest.raises(pl.exceptions.SchemaError):
        q.collect(engine="streaming")

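# A head() spanning several in-memory sources should return only the leading
# rows of the first source.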
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_multiscan_head(
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, io.BytesIO | Path], Any],
) -> None:
    a = io.BytesIO()
    b = io.BytesIO()
    for f in [a, b]:
        write(pl.Series("c1", range(10)).to_frame(), f)
        f.seek(0)

    assert_frame_equal(
        scan([a, b]).head(5).collect(engine="streaming"),
        pl.Series("c1", range(5)).to_frame(),
    )

@pytest.mark.parametrize(
    ("scan", "write"),
    [
        (pl.scan_ipc, pl.DataFrame.write_ipc),
        (pl.scan_parquet, pl.DataFrame.write_parquet),
        (pl.scan_ndjson, pl.DataFrame.write_ndjson),
        (pl.scan_csv, pl.DataFrame.write_csv),
    ],
)
def test_multiscan_tail(
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, io.BytesIO | Path], Any],
) -> None:
    a = io.BytesIO()
    b = io.BytesIO()
    for f in [a, b]:
        write(pl.Series("c1", range(10)).to_frame(), f)
        f.seek(0)

    assert_frame_equal(
        scan([a, b]).tail(5).collect(engine="streaming"),
        pl.Series("c1", range(5, 10)).to_frame(),
    )

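# 13 sources of 7 rows each. offset = 5 * 7 - 5 = 30 starts at row 2 of fs[4],
# so a slice of length 17 covers the tail of fs[4], all of fs[5] and the head
# of fs[6].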
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_multiscan_slice_middle(
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, io.BytesIO | Path], Any],
) -> None:
    fs = [io.BytesIO() for _ in range(13)]
    for f in fs:
        write(pl.Series("c1", range(7)).to_frame(), f)
        f.seek(0)

    offset = 5 * 7 - 5
    expected = (
        list(range(2, 7))  # fs[4]
        + list(range(7))  # fs[5]
        + list(range(5))  # fs[6]
    )
    expected_series = [pl.Series("c1", expected)]
    ri_expected_series = [
        pl.Series("ri", range(offset, offset + 17), get_index_type())
    ] + expected_series

    assert_frame_equal(
        scan(fs).slice(offset, 17).collect(engine="streaming"),
        pl.DataFrame(expected_series),
    )
    assert_frame_equal(
        scan(fs, row_index_name="ri").slice(offset, 17).collect(engine="streaming"),
        pl.DataFrame(ri_expected_series),
    )

    # Negative slices
    offset = -(13 * 7 - offset)
    assert_frame_equal(
        scan(fs).slice(offset, 17).collect(engine="streaming"),
        pl.DataFrame(expected_series),
    )
    assert_frame_equal(
        scan(fs, row_index_name="ri").slice(offset, 17).collect(engine="streaming"),
        pl.DataFrame(ri_expected_series),
    )

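# Property test: slicing 13 small sources must match slicing one equivalent
# concatenated file, with and without a row index.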
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
@given(offset=st.integers(-100, 100), length=st.integers(0, 101))
def test_multiscan_slice_parametric(
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, io.BytesIO | Path], Any],
    offset: int,
    length: int,
) -> None:
    ref = io.BytesIO()
    write(pl.Series("c1", [i % 7 for i in range(13 * 7)]).to_frame(), ref)
    ref.seek(0)

    fs = [io.BytesIO() for _ in range(13)]
    for f in fs:
        write(pl.Series("c1", range(7)).to_frame(), f)
        f.seek(0)

    assert_frame_equal(
        scan(ref).slice(offset, length).collect(),
        scan(fs).slice(offset, length).collect(engine="streaming"),
    )

    ref.seek(0)
    for f in fs:
        f.seek(0)

    assert_frame_equal(
        scan(ref, row_index_name="ri", row_index_offset=42)
        .slice(offset, length)
        .collect(),
        scan(fs, row_index_name="ri", row_index_offset=42)
        .slice(offset, length)
        .collect(engine="streaming"),
    )

    assert_frame_equal(
        scan(ref, row_index_name="ri", row_index_offset=42)
        .slice(offset, length)
        .select("ri")
        .collect(),
        scan(fs, row_index_name="ri", row_index_offset=42)
        .slice(offset, length)
        .select("ri")
        .collect(engine="streaming"),
    )

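# Scans 1023 copies of the same in-memory file, exercising multiscans with a
# large number of sources.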
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_many_files(scan: Any, write: Any) -> None:
    f = io.BytesIO()
    write(pl.DataFrame({"a": [5, 10, 1996]}), f)
    bs = f.getvalue()

    out = scan([bs] * 1023)

    assert_frame_equal(
        out.collect(),
        pl.DataFrame(
            {
                "a": [5, 10, 1996] * 1023,
            }
        ),
    )

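# Regression test for a deadlock when a downstream stop is requested mid-scan;
# the thread count and join sample limit appear to be constrained here to make
# the scenario reproducible.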
def test_deadlock_stop_requested(plmonkeypatch: PlMonkeyPatch) -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        }
    )

    f = io.BytesIO()
    df.write_parquet(f, row_group_size=1)

    plmonkeypatch.setenv("POLARS_MAX_THREADS", "2")
    plmonkeypatch.setenv("POLARS_JOIN_SAMPLE_LIMIT", "1")

    left_fs = [io.BytesIO(f.getbuffer()) for _ in range(10)]
    right_fs = [io.BytesIO(f.getbuffer()) for _ in range(10)]

    left = pl.scan_parquet(left_fs)  # type: ignore[arg-type]
    right = pl.scan_parquet(right_fs)  # type: ignore[arg-type]

    left.join(right, pl.col.a == pl.col.a).collect(engine="streaming").height

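# Regression test for a deadlock when linearizing multiscan output; slice
# pushdown is disabled so that head(100) runs as a separate operator on the
# streamed result.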
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_deadlock_linearize(scan: Any, write: Any) -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        }
    )

    f = io.BytesIO()
    write(df, f)
    fs = [io.BytesIO(f.getbuffer()) for _ in range(10)]
    lf = scan(fs).head(100)

    assert_frame_equal(
        lf.collect(
            engine="streaming", optimizations=pl.QueryOptFlags(slice_pushdown=False)
        ),
        pl.concat([df] * 10),
    )

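# Named for issue 22612: filtering on a row index column must match the
# equivalent slice. The Parquet case writes two row groups so the filter
# window crosses a row-group boundary.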
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_row_index_filter_22612(scan: Any, write: Any) -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        }
    )

    f = io.BytesIO()

    if write is pl.DataFrame.write_parquet:
        df.write_parquet(f, row_group_size=5)
        assert pq.read_metadata(f).num_row_groups == 2
    else:
        write(df, f)

    for end in range(2, 10):
        assert_frame_equal(
            scan(f)
            .with_row_index()
            .filter(pl.col("index") >= end - 2, pl.col("index") <= end)
            .collect(),
            df.with_row_index().slice(end - 2, 3),
        )

        assert_frame_equal(
            scan(f)
            .with_row_index()
            .filter(pl.col("index").is_between(end - 2, end))
            .collect(),
            df.with_row_index().slice(end - 2, 3),
        )

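# Adding a row index whose name already exists as a file column must raise.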
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_row_index_name_in_file(scan: Any, write: Any) -> None:
    f = io.BytesIO()
    write(pl.DataFrame({"index": 1}), f)

    with pytest.raises(
        pl.exceptions.DuplicateError,
        match="cannot add row_index with name 'index': column already exists in file",
    ):
        scan(f).with_row_index().collect()

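# An extra column in a later file errors by default (even with
# missing_columns="insert"); extra_columns="ignore" opts out.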
def test_extra_columns_not_ignored_22218() -> None:
    dfs = [pl.DataFrame({"a": 1, "b": 1}), pl.DataFrame({"a": 2, "c": 2})]

    files: list[IO[bytes]] = [io.BytesIO(), io.BytesIO()]

    dfs[0].write_parquet(files[0])
    dfs[1].write_parquet(files[1])

    with pytest.raises(
        pl.exceptions.SchemaError,
        match=r"extra column in file outside of expected schema: c, hint: specify .*or pass",
    ):
        pl.scan_parquet(files, missing_columns="insert").select(pl.all()).collect()

    assert_frame_equal(
        pl.scan_parquet(
            files,
            missing_columns="insert",
            extra_columns="ignore",
        )
        .select(pl.all())
        .collect(),
        pl.DataFrame({"a": [1, 2], "b": [1, None]}),
    )

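# A file containing only nulls should be upcast to the dtype found in the
# other files.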
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_scan_null_upcast(scan: Any, write: Any) -> None:
    dfs = [
        pl.DataFrame({"a": [1, 2, 3]}),
        pl.select(a=pl.lit(None, dtype=pl.Null)),
    ]

    files = [io.BytesIO(), io.BytesIO()]

    write(dfs[0], files[0])
    write(dfs[1], files[1])

    # Prevent CSV schema inference from loading as string (it looks at multiple
    # files).
    if scan is pl.scan_csv:
        scan = partial(scan, schema=dfs[0].schema)

    assert_frame_equal(
        scan(files).collect(),
        pl.DataFrame({"a": [1, 2, 3, None]}),
    )

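# As above, but upcasting Null to a nested List(Struct) dtype.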
@pytest.mark.parametrize(
    ("scan", "write"),
    [
        (pl.scan_ipc, pl.DataFrame.write_ipc),
        (pl.scan_parquet, pl.DataFrame.write_parquet),
        (pl.scan_ndjson, pl.DataFrame.write_ndjson),
    ],
)
def test_scan_null_upcast_to_nested(scan: Any, write: Any) -> None:
    schema = {"a": pl.List(pl.Struct({"field": pl.Int64}))}

    dfs = [
        pl.DataFrame(
            {"a": [[{"field": 1}], [{"field": 2}], []]},
            schema=schema,
        ),
        pl.select(a=pl.lit(None, dtype=pl.Null)),
    ]

    files = [io.BytesIO(), io.BytesIO()]

    write(dfs[0], files[0])
    write(dfs[1], files[1])

    # Prevent CSV schema inference from loading as string (it looks at multiple
    # files).
    if scan is pl.scan_csv:
        scan = partial(scan, schema=schema)

    assert_frame_equal(
        scan(files).collect(),
        pl.DataFrame(
            {"a": [[{"field": 1}], [{"field": 2}], [], None]},
            schema=schema,
        ),
    )

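# `hidden_file_prefix` should skip files (but, per the expectations below, not
# directories) whose names start with one of the given prefixes during path
# expansion.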
@pytest.mark.parametrize(
    ("scan", "write"),
    [
        (pl.scan_parquet, pl.DataFrame.write_parquet),
    ],
)
@pytest.mark.parametrize(
    "prefix",
    [
        "",
        "file:" if sys.platform != "win32" else "file:/",
        "file://" if sys.platform != "win32" else "file:///",
    ],
)
@pytest.mark.parametrize("use_glob", [True, False])
def test_scan_ignore_hidden_files_21762(
    tmp_path: Path, scan: Any, write: Any, use_glob: bool, prefix: str
) -> None:
    file_names: list[str] = ["a.ext", "_a.ext", ".a.ext", "a_.ext"]

    for file_name in file_names:
        write(pl.DataFrame({"rel_path": file_name}), tmp_path / file_name)

    (tmp_path / "folder").mkdir()

    for file_name in file_names:
        write(
            pl.DataFrame({"rel_path": f"folder/{file_name}"}),
            tmp_path / "folder" / file_name,
        )

    (tmp_path / "_folder").mkdir()

    for file_name in file_names:
        write(
            pl.DataFrame({"rel_path": f"_folder/{file_name}"}),
            tmp_path / "_folder" / file_name,
        )

    suffix = "/**/*.ext" if use_glob else "/" if prefix.startswith("file:") else ""
    root = f"{prefix}{tmp_path}{suffix}"

    assert_frame_equal(
        scan(root).sort("*"),
        pl.LazyFrame(
            {
                "rel_path": [
                    ".a.ext",
                    "_a.ext",
                    "_folder/.a.ext",
                    "_folder/_a.ext",
                    "_folder/a.ext",
                    "_folder/a_.ext",
                    "a.ext",
                    "a_.ext",
                    "folder/.a.ext",
                    "folder/_a.ext",
                    "folder/a.ext",
                    "folder/a_.ext",
                ]
            }
        ),
    )

    assert_frame_equal(
        scan(root, hidden_file_prefix=".").sort("*"),
        pl.LazyFrame(
            {
                "rel_path": [
                    "_a.ext",
                    "_folder/_a.ext",
                    "_folder/a.ext",
                    "_folder/a_.ext",
                    "a.ext",
                    "a_.ext",
                    "folder/_a.ext",
                    "folder/a.ext",
                    "folder/a_.ext",
                ]
            }
        ),
    )

    assert_frame_equal(
        scan(root, hidden_file_prefix=[".", "_"]).sort("*"),
        pl.LazyFrame(
            {
                "rel_path": [
                    "_folder/a.ext",
                    "_folder/a_.ext",
                    "a.ext",
                    "a_.ext",
                    "folder/a.ext",
                    "folder/a_.ext",
                ]
            }
        ),
    )

    assert_frame_equal(
        scan(root, hidden_file_prefix=(".", "_")).sort("*"),
        pl.LazyFrame(
            {
                "rel_path": [
                    "_folder/a.ext",
                    "_folder/a_.ext",
                    "a.ext",
                    "a_.ext",
                    "folder/a.ext",
                    "folder/a_.ext",
                ]
            }
        ),
    )

    # Top-level glob only
    root = f"{tmp_path}/*.ext"

    assert_frame_equal(
        scan(root).sort("*"),
        pl.LazyFrame(
            {
                "rel_path": [
                    ".a.ext",
                    "_a.ext",
                    "a.ext",
                    "a_.ext",
                ]
            }
        ),
    )

    assert_frame_equal(
        scan(root, hidden_file_prefix=".").sort("*"),
        pl.LazyFrame(
            {
                "rel_path": [
                    "_a.ext",
                    "a.ext",
                    "a_.ext",
                ]
            }
        ),
    )

    assert_frame_equal(
        scan(root, hidden_file_prefix=[".", "_"]).sort("*"),
        pl.LazyFrame(
            {
                "rel_path": [
                    "a.ext",
                    "a_.ext",
                ]
            }
        ),
    )

    # Direct file passed
    with pytest.raises(pl.exceptions.ComputeError, match="expanded paths were empty"):
        scan(tmp_path / "_a.ext", hidden_file_prefix="_").collect()

def test_row_count_estimate_multifile(io_files_path: Path) -> None:
    src = io_files_path / "foods*.parquet"
    # Test that the estimate is not based on the first file only.
    assert "ESTIMATED ROWS: 54" in pl.scan_parquet(src).explain()

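# A predicate that combines a column filter with a boolean literal (in either
# order) must still be applied correctly over hive partitions.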
@pytest.mark.parametrize(
    ("scan", "write", "ext"),
    [
        (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc"),
        (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet"),
    ],
)
@pytest.mark.parametrize(
    ("predicate", "expected_indices"),
    [
        ((pl.col.x == 1) & True, [0]),
        (True & (pl.col.x == 1), [0]),
    ],
)
@pytest.mark.write_disk
def test_hive_predicate_filtering_edge_case_25630(
    tmp_path: Path,
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, Path], Any],
    ext: str,
    predicate: pl.Expr,
    expected_indices: list[int],
) -> None:
    df = pl.DataFrame({"x": [1, 2, 3], "y": [0, 1, 1]}).with_row_index()

    (tmp_path / "y=0").mkdir()
    (tmp_path / "y=1").mkdir()

    # previously we could panic if hive columns were all filtered out of the projection
    write(df.filter(pl.col.y == 0).drop("y"), tmp_path / "y=0" / f"data.{ext}")
    write(df.filter(pl.col.y == 1).drop("y"), tmp_path / "y=1" / f"data.{ext}")

    res = scan(tmp_path).filter(predicate).select("index").collect(engine="streaming")
    expected = pl.DataFrame(
        data={"index": expected_indices},
        schema={"index": pl.get_index_type()},
    )
    assert_frame_equal(res, expected)