GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/test_scan.py
from __future__ import annotations

import io
import sys
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from math import ceil
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

import pytest

import polars as pl
from polars.testing.asserts.frame import assert_frame_equal

if TYPE_CHECKING:
    from polars._typing import SchemaDict


@dataclass
class _RowIndex:
    name: str = "index"
    offset: int = 0


def _enable_force_async(monkeypatch: pytest.MonkeyPatch) -> None:
    """Force the async reading code path via the provided monkeypatch context."""
    monkeypatch.setenv("POLARS_VERBOSE", "1")
    monkeypatch.setenv("POLARS_FORCE_ASYNC", "1")


def _scan(
    file_path: Path,
    schema: SchemaDict | None = None,
    row_index: _RowIndex | None = None,
) -> pl.LazyFrame:
    # Dispatch to the scan function matching the file extension.
    suffix = file_path.suffix
    row_index_name = None if row_index is None else row_index.name
    row_index_offset = 0 if row_index is None else row_index.offset

    if (
        scan_func := {
            ".ipc": pl.scan_ipc,
            ".parquet": pl.scan_parquet,
            ".csv": pl.scan_csv,
            ".ndjson": pl.scan_ndjson,
        }.get(suffix)
    ) is not None:  # fmt: skip
        result = scan_func(
            file_path,
            row_index_name=row_index_name,
            row_index_offset=row_index_offset,
        )  # type: ignore[operator]

    else:
        msg = f"Unknown suffix {suffix}"
        raise NotImplementedError(msg)

    return result  # type: ignore[no-any-return]


def _write(df: pl.DataFrame, file_path: Path) -> None:
    # Dispatch to the write method matching the file extension.
    suffix = file_path.suffix

    if (
        write_func := {
            ".ipc": pl.DataFrame.write_ipc,
            ".parquet": pl.DataFrame.write_parquet,
            ".csv": pl.DataFrame.write_csv,
            ".ndjson": pl.DataFrame.write_ndjson,
        }.get(suffix)
    ) is not None:  # fmt: skip
        return write_func(df, file_path)  # type: ignore[operator, no-any-return]

    msg = f"Unknown suffix {suffix}"
    raise NotImplementedError(msg)


@pytest.fixture(
    scope="session",
    params=["csv", "ipc", "parquet", "ndjson"],
)
def data_file_extension(request: pytest.FixtureRequest) -> str:
    return f".{request.param}"


@pytest.fixture(scope="session")
def session_tmp_dir(tmp_path_factory: pytest.TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("polars-test")


@pytest.fixture(
    params=[False, True],
    ids=["sync", "async"],
)
def force_async(
    request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
) -> bool:
    value: bool = request.param
    return value


@dataclass
class _DataFile:
    path: Path
    df: pl.DataFrame


def df_with_chunk_size_limit(df: pl.DataFrame, limit: int) -> pl.DataFrame:
    return pl.concat(
        (
            df.slice(i * limit, min(limit, df.height - i * limit))
            for i in range(ceil(df.height / limit))
        ),
        rechunk=False,
    )


@pytest.fixture(scope="session")
def data_file_single(session_tmp_dir: Path, data_file_extension: str) -> _DataFile:
    max_rows_per_batch = 727
    file_path = (session_tmp_dir / "data").with_suffix(data_file_extension)
    df = pl.DataFrame(
        {
            "sequence": range(10000),
        }
    )
    assert max_rows_per_batch < df.height
    _write(df_with_chunk_size_limit(df, max_rows_per_batch), file_path)
    return _DataFile(path=file_path, df=df)


@pytest.fixture(scope="session")
def data_file_glob(session_tmp_dir: Path, data_file_extension: str) -> _DataFile:
    max_rows_per_batch = 200
    row_counts = [
        100, 186, 95, 185, 90, 84, 115, 81, 87, 217, 126, 85, 98, 122, 129, 122, 1089, 82,
        234, 86, 93, 90, 91, 263, 87, 126, 86, 161, 191, 1368, 403, 192, 102, 98, 115, 81,
        111, 305, 92, 534, 431, 150, 90, 128, 152, 118, 127, 124, 229, 368, 81,
    ]  # fmt: skip
    assert sum(row_counts) == 10000

    # Make sure we pad file names with enough zeros to ensure correct
    # lexicographical ordering.
    assert len(row_counts) < 100

    # Make sure that some of our data frames consist of multiple chunks which
    # affects the output of certain file formats.
    assert any(row_count > max_rows_per_batch for row_count in row_counts)
    df = pl.DataFrame(
        {
            "sequence": range(10000),
        }
    )

    row_offset = 0
    for index, row_count in enumerate(row_counts):
        file_path = (session_tmp_dir / f"data_{index:02}").with_suffix(
            data_file_extension
        )
        _write(
            df_with_chunk_size_limit(
                df.slice(row_offset, row_count), max_rows_per_batch
            ),
            file_path,
        )
        row_offset += row_count
    return _DataFile(
        path=(session_tmp_dir / "data_*").with_suffix(data_file_extension), df=df
    )


@pytest.fixture(scope="session", params=["single", "glob"])
def data_file(
    request: pytest.FixtureRequest,
    data_file_single: _DataFile,
    data_file_glob: _DataFile,
) -> _DataFile:
    if request.param == "single":
        return data_file_single
    if request.param == "glob":
        return data_file_glob
    raise NotImplementedError()


@pytest.mark.write_disk
def test_scan(
    capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool
) -> None:
    if force_async:
        _enable_force_async(monkeypatch)

    df = _scan(data_file.path, data_file.df.schema).collect()

    assert_frame_equal(df, data_file.df)


@pytest.mark.write_disk
def test_scan_with_limit(
    capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool
) -> None:
    if force_async:
        _enable_force_async(monkeypatch)

    df = _scan(data_file.path, data_file.df.schema).limit(4483).collect()

    assert_frame_equal(
        df,
        pl.DataFrame(
            {
                "sequence": range(4483),
            }
        ),
    )


@pytest.mark.write_disk
def test_scan_with_filter(
    capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool
) -> None:
    if force_async:
        _enable_force_async(monkeypatch)

    df = (
        _scan(data_file.path, data_file.df.schema)
        .filter(pl.col("sequence") % 2 == 0)
        .collect()
    )

    assert_frame_equal(
        df,
        pl.DataFrame(
            {
                "sequence": (2 * x for x in range(5000)),
            }
        ),
    )


@pytest.mark.write_disk
def test_scan_with_filter_and_limit(
    capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool
) -> None:
    if force_async:
        _enable_force_async(monkeypatch)

    df = (
        _scan(data_file.path, data_file.df.schema)
        .filter(pl.col("sequence") % 2 == 0)
        .limit(4483)
        .collect()
    )

    assert_frame_equal(
        df,
        pl.DataFrame(
            {
                "sequence": (2 * x for x in range(4483)),
            },
        ),
    )


@pytest.mark.write_disk
def test_scan_with_limit_and_filter(
    capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool
) -> None:
    if force_async:
        _enable_force_async(monkeypatch)

    df = (
        _scan(data_file.path, data_file.df.schema)
        .limit(4483)
        .filter(pl.col("sequence") % 2 == 0)
        .collect()
    )

    assert_frame_equal(
        df,
        pl.DataFrame(
            {
                "sequence": (2 * x for x in range(2242)),
            },
        ),
    )


@pytest.mark.write_disk
def test_scan_with_row_index_and_limit(
    capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool
) -> None:
    if force_async:
        _enable_force_async(monkeypatch)

    df = (
        _scan(data_file.path, data_file.df.schema, row_index=_RowIndex())
        .limit(4483)
        .collect()
    )

    assert_frame_equal(
        df,
        pl.DataFrame(
            {
                "index": range(4483),
                "sequence": range(4483),
            },
            schema_overrides={"index": pl.UInt32},
        ),
    )


@pytest.mark.write_disk
def test_scan_with_row_index_and_filter(
    capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool
) -> None:
    if force_async:
        _enable_force_async(monkeypatch)

    df = (
        _scan(data_file.path, data_file.df.schema, row_index=_RowIndex())
        .filter(pl.col("sequence") % 2 == 0)
        .collect()
    )

    assert_frame_equal(
        df,
        pl.DataFrame(
            {
                "index": (2 * x for x in range(5000)),
                "sequence": (2 * x for x in range(5000)),
            },
            schema_overrides={"index": pl.UInt32},
        ),
    )


@pytest.mark.write_disk
def test_scan_with_row_index_limit_and_filter(
    capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool
) -> None:
    if force_async:
        _enable_force_async(monkeypatch)

    df = (
        _scan(data_file.path, data_file.df.schema, row_index=_RowIndex())
        .limit(4483)
        .filter(pl.col("sequence") % 2 == 0)
        .collect()
    )

    assert_frame_equal(
        df,
        pl.DataFrame(
            {
                "index": (2 * x for x in range(2242)),
                "sequence": (2 * x for x in range(2242)),
            },
            schema_overrides={"index": pl.UInt32},
        ),
    )


@pytest.mark.write_disk
def test_scan_with_row_index_projected_out(
    capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool
) -> None:
    if data_file.path.suffix == ".csv" and force_async:
        pytest.skip(reason="async reading of .csv not yet implemented")

    if force_async:
        _enable_force_async(monkeypatch)

    subset = next(iter(data_file.df.schema.keys()))
    df = (
        _scan(data_file.path, data_file.df.schema, row_index=_RowIndex())
        .select(subset)
        .collect()
    )

    assert_frame_equal(df, data_file.df.select(subset))


@pytest.mark.write_disk
def test_scan_with_row_index_filter_and_limit(
    capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool
) -> None:
    if data_file.path.suffix == ".csv" and force_async:
        pytest.skip(reason="async reading of .csv not yet implemented")

    if force_async:
        _enable_force_async(monkeypatch)

    df = (
        _scan(data_file.path, data_file.df.schema, row_index=_RowIndex())
        .filter(pl.col("sequence") % 2 == 0)
        .limit(4483)
        .collect()
    )

    assert_frame_equal(
        df,
        pl.DataFrame(
            {
                "index": (2 * x for x in range(4483)),
                "sequence": (2 * x for x in range(4483)),
            },
            schema_overrides={"index": pl.UInt32},
        ),
    )


@pytest.mark.write_disk
@pytest.mark.parametrize(
    ("scan_func", "write_func"),
    [
        (pl.scan_parquet, pl.DataFrame.write_parquet),
        (pl.scan_ipc, pl.DataFrame.write_ipc),
        (pl.scan_csv, pl.DataFrame.write_csv),
        (pl.scan_ndjson, pl.DataFrame.write_ndjson),
    ],
)
@pytest.mark.parametrize(
    "streaming",
    [True, False],
)
def test_scan_limit_0_does_not_panic(
    tmp_path: Path,
    scan_func: Callable[[Any], pl.LazyFrame],
    write_func: Callable[[pl.DataFrame, Path], None],
    streaming: bool,
) -> None:
    tmp_path.mkdir(exist_ok=True)
    path = tmp_path / "data.bin"
    df = pl.DataFrame({"x": 1})
    write_func(df, path)
    assert_frame_equal(
        scan_func(path)
        .head(0)
        .collect(engine="streaming" if streaming else "in-memory"),
        df.clear(),
    )


@pytest.mark.write_disk
@pytest.mark.parametrize(
    ("scan_func", "write_func"),
    [
        (pl.scan_csv, pl.DataFrame.write_csv),
        (pl.scan_parquet, pl.DataFrame.write_parquet),
        (pl.scan_ipc, pl.DataFrame.write_ipc),
        (pl.scan_ndjson, pl.DataFrame.write_ndjson),
    ],
)
@pytest.mark.parametrize(
    "glob",
    [True, False],
)
def test_scan_directory(
    tmp_path: Path,
    scan_func: Callable[..., pl.LazyFrame],
    write_func: Callable[[pl.DataFrame, Path], None],
    glob: bool,
) -> None:
    tmp_path.mkdir(exist_ok=True)

    dfs: list[pl.DataFrame] = [
        pl.DataFrame({"a": [0, 0, 0, 0, 0]}),
        pl.DataFrame({"a": [1, 1, 1, 1, 1]}),
        pl.DataFrame({"a": [2, 2, 2, 2, 2]}),
    ]

    paths = [
        tmp_path / "0.bin",
        tmp_path / "1.bin",
        tmp_path / "dir/data.bin",
    ]

    for df, path in zip(dfs, paths):
        path.parent.mkdir(exist_ok=True)
        write_func(df, path)

    df = pl.concat(dfs)

    scan = scan_func

    if scan_func in [pl.scan_csv, pl.scan_ndjson]:
        scan = partial(scan, schema=df.schema)

    if scan_func is pl.scan_parquet:
        scan = partial(scan, glob=glob)

    out = scan(tmp_path).collect()
    assert_frame_equal(out, df)


@pytest.mark.write_disk
def test_scan_glob_excludes_directories(tmp_path: Path) -> None:
    for dir in ["dir1", "dir2", "dir3"]:
        (tmp_path / dir).mkdir()

    df = pl.DataFrame({"a": [1, 2, 3]})

    df.write_parquet(tmp_path / "dir1/data.bin")
    df.write_parquet(tmp_path / "dir2/data.parquet")
    df.write_parquet(tmp_path / "data.parquet")

    assert_frame_equal(pl.scan_parquet(tmp_path / "**/*.bin").collect(), df)
    assert_frame_equal(pl.scan_parquet(tmp_path / "**/data*.bin").collect(), df)
    assert_frame_equal(
        pl.scan_parquet(tmp_path / "**/*").collect(), pl.concat(3 * [df])
    )
    assert_frame_equal(pl.scan_parquet(tmp_path / "*").collect(), df)


@pytest.mark.parametrize("file_name", ["a b", "a %25 b"])
@pytest.mark.write_disk
def test_scan_async_whitespace_in_path(
    tmp_path: Path, monkeypatch: Any, file_name: str
) -> None:
    monkeypatch.setenv("POLARS_FORCE_ASYNC", "1")
    tmp_path.mkdir(exist_ok=True)

    path = tmp_path / f"{file_name}.parquet"
    df = pl.DataFrame({"x": 1})
    df.write_parquet(path)
    assert_frame_equal(pl.scan_parquet(path).collect(), df)
    assert_frame_equal(pl.scan_parquet(tmp_path).collect(), df)
    assert_frame_equal(pl.scan_parquet(tmp_path / "*").collect(), df)
    assert_frame_equal(pl.scan_parquet(tmp_path / "*.parquet").collect(), df)
    path.unlink()


@pytest.mark.write_disk
def test_path_expansion_excludes_empty_files_17362(tmp_path: Path) -> None:
    tmp_path.mkdir(exist_ok=True)

    df = pl.DataFrame({"x": 1})
    df.write_parquet(tmp_path / "data.parquet")
    (tmp_path / "empty").touch()

    assert_frame_equal(pl.scan_parquet(tmp_path).collect(), df)
    assert_frame_equal(pl.scan_parquet(tmp_path / "*").collect(), df)


@pytest.mark.write_disk
def test_path_expansion_empty_directory_does_not_panic(tmp_path: Path) -> None:
    tmp_path.mkdir(exist_ok=True)

    with pytest.raises(pl.exceptions.ComputeError):
        pl.scan_parquet(tmp_path).collect()

    with pytest.raises(pl.exceptions.ComputeError):
        pl.scan_parquet(tmp_path / "**/*").collect()


@pytest.mark.write_disk
def test_scan_single_dir_differing_file_extensions_raises_17436(tmp_path: Path) -> None:
    tmp_path.mkdir(exist_ok=True)

    df = pl.DataFrame({"x": 1})
    df.write_parquet(tmp_path / "data.parquet")
    df.write_ipc(tmp_path / "data.ipc")

    with pytest.raises(
        pl.exceptions.InvalidOperationError, match="different file extensions"
    ):
        pl.scan_parquet(tmp_path).collect()

    for lf in [
        pl.scan_parquet(tmp_path / "*.parquet"),
        pl.scan_ipc(tmp_path / "*.ipc"),
    ]:
        assert_frame_equal(lf.collect(), df)

    # Ensure passing a glob doesn't trigger file extension checking
    with pytest.raises(
        pl.exceptions.ComputeError,
        match="parquet: File out of specification: The file must end with PAR1",
    ):
        pl.scan_parquet(tmp_path / "*").collect()


@pytest.mark.parametrize("format", ["parquet", "csv", "ndjson", "ipc"])
def test_scan_nonexistent_path(format: str) -> None:
    path_str = f"my-nonexistent-data.{format}"
    path = Path(path_str)
    assert not path.exists()

    scan_function = getattr(pl, f"scan_{format}")

    # Just calling the scan function should not raise any errors
    result = scan_function(path)
    assert isinstance(result, pl.LazyFrame)

    # Upon collection, it should fail
    with pytest.raises(FileNotFoundError):
        result.collect()


@pytest.mark.write_disk
@pytest.mark.parametrize(
    ("scan_func", "write_func"),
    [
        (pl.scan_parquet, pl.DataFrame.write_parquet),
        (pl.scan_ipc, pl.DataFrame.write_ipc),
        (pl.scan_csv, pl.DataFrame.write_csv),
        (pl.scan_ndjson, pl.DataFrame.write_ndjson),
    ],
)
@pytest.mark.parametrize(
    "streaming",
    [True, False],
)
def test_scan_include_file_paths(
    tmp_path: Path,
    scan_func: Callable[..., pl.LazyFrame],
    write_func: Callable[[pl.DataFrame, Path], None],
    streaming: bool,
) -> None:
    tmp_path.mkdir(exist_ok=True)
    dfs: list[pl.DataFrame] = []

    for x in ["1", "2"]:
        path = Path(f"{tmp_path}/{x}.bin").absolute()
        dfs.append(pl.DataFrame({"x": 10 * [x]}).with_columns(path=pl.lit(str(path))))
        write_func(dfs[-1].drop("path"), path)

    df = pl.concat(dfs)
    assert df.columns == ["x", "path"]

    with pytest.raises(
        pl.exceptions.DuplicateError,
        match=r'column name for file paths "x" conflicts with column name from file',
    ):
        scan_func(tmp_path, include_file_paths="x").collect(
            engine="streaming" if streaming else "in-memory"
        )

    f = scan_func
    if scan_func in [pl.scan_csv, pl.scan_ndjson]:
        f = partial(f, schema=df.drop("path").schema)

    lf: pl.LazyFrame = f(tmp_path, include_file_paths="path")
    assert_frame_equal(lf.collect(engine="streaming" if streaming else "in-memory"), df)

    # Test projecting only the path column
    q = lf.select("path")
    assert q.collect_schema() == {"path": pl.String}
    assert_frame_equal(
        q.collect(engine="streaming" if streaming else "in-memory"),
        df.select("path"),
    )

    q = q.select("path").head(3)
    assert q.collect_schema() == {"path": pl.String}
    assert_frame_equal(
        q.collect(engine="streaming" if streaming else "in-memory"),
        df.select("path").head(3),
    )

    # Test predicates
    for predicate in [pl.col("path") != pl.col("x"), pl.col("path") != ""]:
        assert_frame_equal(
            lf.filter(predicate).collect(
                engine="streaming" if streaming else "in-memory"
            ),
            df,
        )

    # Test codepaths that materialize empty DataFrames
    assert_frame_equal(
        lf.head(0).collect(engine="streaming" if streaming else "in-memory"),
        df.head(0),
    )


@pytest.mark.write_disk
def test_async_path_expansion_bracket_17629(tmp_path: Path) -> None:
    path = tmp_path / "data.parquet"

    df = pl.DataFrame({"x": 1})
    df.write_parquet(path)

    assert_frame_equal(pl.scan_parquet(tmp_path / "[d]ata.parquet").collect(), df)


@pytest.mark.parametrize(
    "method",
    ["parquet", "csv", "ipc", "ndjson"],
)
@pytest.mark.may_fail_auto_streaming  # unsupported negative slice offset -1 for CSV source
def test_scan_in_memory(method: str) -> None:
    f = io.BytesIO()
    df = pl.DataFrame(
        {
            "a": [1, 2, 3],
            "b": ["x", "y", "z"],
        }
    )

    (getattr(df, f"write_{method}"))(f)

    f.seek(0)
    result = (getattr(pl, f"scan_{method}"))(f).collect()
    assert_frame_equal(df, result)

    f.seek(0)
    result = (getattr(pl, f"scan_{method}"))(f).slice(1, 2).collect()
    assert_frame_equal(df.slice(1, 2), result)

    f.seek(0)
    result = (getattr(pl, f"scan_{method}"))(f).slice(-1, 1).collect()
    assert_frame_equal(df.slice(-1, 1), result)

    g = io.BytesIO()
    (getattr(df, f"write_{method}"))(g)

    f.seek(0)
    g.seek(0)
    result = (getattr(pl, f"scan_{method}"))([f, g]).collect()
    assert_frame_equal(df.vstack(df), result)

    f.seek(0)
    g.seek(0)
    result = (getattr(pl, f"scan_{method}"))([f, g]).slice(1, 2).collect()
    assert_frame_equal(df.vstack(df).slice(1, 2), result)

    f.seek(0)
    g.seek(0)
    result = (getattr(pl, f"scan_{method}"))([f, g]).slice(-1, 1).collect()
    assert_frame_equal(df.vstack(df).slice(-1, 1), result)


def test_scan_pyobject_zero_copy_buffer_mutate() -> None:
    f = io.BytesIO()

    df = pl.DataFrame({"x": [1, 2, 3, 4, 5]})
    df.write_ipc(f)
    f.seek(0)

    q = pl.scan_ipc(f)
    assert_frame_equal(q.collect(), df)

    f.write(b"AAA")
    assert_frame_equal(q.collect(), df)


@pytest.mark.parametrize(
    "method",
    ["csv", "ndjson"],
)
def test_scan_stringio(method: str) -> None:
    f = io.StringIO()
    df = pl.DataFrame(
        {
            "a": [1, 2, 3],
            "b": ["x", "y", "z"],
        }
    )

    (getattr(df, f"write_{method}"))(f)

    f.seek(0)
    result = (getattr(pl, f"scan_{method}"))(f).collect()
    assert_frame_equal(df, result)

    g = io.StringIO()
    (getattr(df, f"write_{method}"))(g)

    f.seek(0)
    g.seek(0)
    result = (getattr(pl, f"scan_{method}"))([f, g]).collect()
    assert_frame_equal(df.vstack(df), result)


def test_scan_double_collect_row_index_invalidates_cached_ir_18892() -> None:
    lf = pl.scan_csv(io.BytesIO(b"a\n1\n2\n3"))

    lf.collect()

    out = lf.with_row_index().collect()

    assert_frame_equal(
        out,
        pl.DataFrame(
            {"index": [0, 1, 2], "a": [1, 2, 3]},
            schema={"index": pl.UInt32, "a": pl.Int64},
        ),
    )


def test_scan_include_file_paths_respects_projection_pushdown() -> None:
    q = pl.scan_csv(b"a,b,c\na1,b1,c1", include_file_paths="path_name").select(
        ["a", "b"]
    )

    assert_frame_equal(q.collect(), pl.DataFrame({"a": "a1", "b": "b1"}))


def test_streaming_scan_csv_include_file_paths_18257(io_files_path: Path) -> None:
    lf = pl.scan_csv(
        io_files_path / "foods1.csv",
        include_file_paths="path",
    ).select("category", "path")

    assert lf.collect(engine="streaming").columns == ["category", "path"]


def test_streaming_scan_csv_with_row_index_19172(io_files_path: Path) -> None:
    lf = (
        pl.scan_csv(io_files_path / "foods1.csv", infer_schema=False)
        .with_row_index()
        .select("calories", "index")
        .head(1)
    )

    assert_frame_equal(
        lf.collect(engine="streaming"),
        pl.DataFrame(
            {"calories": "45", "index": 0},
            schema={"calories": pl.String, "index": pl.UInt32},
        ),
    )


@pytest.mark.write_disk
def test_predicate_hive_pruning_with_cast(tmp_path: Path) -> None:
    tmp_path.mkdir(exist_ok=True)

    df = pl.DataFrame({"x": 1})

    (p := (tmp_path / "date=2024-01-01")).mkdir()

    df.write_parquet(p / "1")

    (p := (tmp_path / "date=2024-01-02")).mkdir()

    # Write an invalid parquet file that will cause errors if polars attempts to
    # read it.
    # This works because `scan_parquet()` only looks at the first file during
    # schema inference.
    (p / "1").write_text("not a parquet file")

    expect = pl.DataFrame({"x": 1, "date": datetime(2024, 1, 1).date()})

    lf = pl.scan_parquet(tmp_path)

    q = lf.filter(pl.col("date") < datetime(2024, 1, 2).date())

    assert_frame_equal(q.collect(), expect)

    # This filter expr with strptime is effectively what LazyFrame.sql()
    # generates
    q = lf.filter(
        pl.col("date")
        < pl.lit("2024-01-02").str.strptime(
            dtype=pl.Date, format="%Y-%m-%d", ambiguous="latest"
        )
    )

    assert_frame_equal(q.collect(), expect)

    q = lf.sql("select * from self where date < '2024-01-02'")
    print(q.explain())
    assert_frame_equal(q.collect(), expect)


def test_predicate_stats_eval_nested_binary() -> None:
    bufs: list[bytes] = []

    for i in range(10):
        b = io.BytesIO()
        pl.DataFrame({"x": i}).write_parquet(b)
        b.seek(0)
        bufs.append(b.read())

    assert_frame_equal(
        (
            pl.scan_parquet(bufs)
            .filter(pl.col("x") % 2 == 0)
            .collect(optimizations=pl.QueryOptFlags.none())
        ),
        pl.DataFrame({"x": [0, 2, 4, 6, 8]}),
    )

    assert_frame_equal(
        (
            pl.scan_parquet(bufs)
            # The literal eval depth limit is 4 -
            # * crates/polars-expr/src/expressions/mod.rs::PhysicalExpr::evaluate_inline
            .filter(pl.col("x") == pl.lit("222").str.slice(0, 1).cast(pl.Int64))
            .collect()
        ),
        pl.DataFrame({"x": [2]}),
    )


@pytest.mark.slow
@pytest.mark.parametrize("streaming", [True, False])
def test_scan_csv_bytesio_memory_usage(
    streaming: bool,
    # memory_usage_without_pyarrow: MemoryUsage,
) -> None:
    # memory_usage = memory_usage_without_pyarrow

    # Create CSV that is ~6-7 MB in size:
    f = io.BytesIO()
    df = pl.DataFrame({"mydata": pl.int_range(0, 1_000_000, eager=True)})
    df.write_csv(f)
    # assert 6_000_000 < f.tell() < 7_000_000
    f.seek(0, 0)

    # A lazy scan shouldn't make a full copy of the data:
    # starting_memory = memory_usage.get_current()
    assert (
        pl.scan_csv(f)
        .filter(pl.col("mydata") == 999_999)
        .collect(engine="streaming" if streaming else "in-memory")
        .item()
        == 999_999
    )
    # assert memory_usage.get_peak() - starting_memory < 1_000_000


@pytest.mark.parametrize(
    "scan_type",
    [
        (pl.DataFrame.write_parquet, pl.scan_parquet),
        (pl.DataFrame.write_ipc, pl.scan_ipc),
        (pl.DataFrame.write_csv, pl.scan_csv),
        (pl.DataFrame.write_ndjson, pl.scan_ndjson),
    ],
)
def test_only_project_row_index(scan_type: tuple[Any, Any]) -> None:
    write, scan = scan_type

    f = io.BytesIO()
    df = pl.DataFrame([pl.Series("a", [1, 2, 3], pl.UInt32)])
    write(df, f)

    f.seek(0)
    s = scan(f, row_index_name="row_index", row_index_offset=42)

    assert_frame_equal(
        s.select("row_index").collect(),
        pl.DataFrame({"row_index": [42, 43, 44]}),
        check_dtypes=False,
    )


@pytest.mark.parametrize(
    "scan_type",
    [
        (pl.DataFrame.write_parquet, pl.scan_parquet),
        (pl.DataFrame.write_ipc, pl.scan_ipc),
        (pl.DataFrame.write_csv, pl.scan_csv),
        (pl.DataFrame.write_ndjson, pl.scan_ndjson),
    ],
)
def test_only_project_include_file_paths(scan_type: tuple[Any, Any]) -> None:
    write, scan = scan_type

    f = io.BytesIO()
    df = pl.DataFrame([pl.Series("a", [1, 2, 3], pl.UInt32)])
    write(df, f)

    f.seek(0)
    s = scan(f, include_file_paths="file_path")

    # The exact value for in-memory buffers is undefined
    c = s.select("file_path").collect()
    assert c.height == 3
    assert c.columns == ["file_path"]


@pytest.mark.parametrize(
    "scan_type",
    [
        (pl.DataFrame.write_parquet, pl.scan_parquet),
        pytest.param(
            (pl.DataFrame.write_ipc, pl.scan_ipc),
            marks=pytest.mark.xfail(
                reason="has no allow_missing_columns parameter. https://github.com/pola-rs/polars/issues/21166"
            ),
        ),
        pytest.param(
            (pl.DataFrame.write_csv, pl.scan_csv),
            marks=pytest.mark.xfail(
                reason="has no allow_missing_columns parameter. https://github.com/pola-rs/polars/issues/21166"
            ),
        ),
        pytest.param(
            (pl.DataFrame.write_ndjson, pl.scan_ndjson),
            marks=pytest.mark.xfail(
                reason="has no allow_missing_columns parameter. https://github.com/pola-rs/polars/issues/21166"
            ),
        ),
    ],
)
def test_only_project_missing(scan_type: tuple[Any, Any]) -> None:
    write, scan = scan_type

    f = io.BytesIO()
    g = io.BytesIO()
    write(
        pl.DataFrame(
            [pl.Series("a", [], pl.UInt32), pl.Series("missing", [], pl.Int32)]
        ),
        f,
    )
    write(pl.DataFrame([pl.Series("a", [1, 2, 3], pl.UInt32)]), g)

    f.seek(0)
    g.seek(0)
    s = scan([f, g], missing_columns="insert")

    assert_frame_equal(
        s.select("missing").collect(),
        pl.DataFrame([pl.Series("missing", [None, None, None], pl.Int32)]),
    )


@pytest.mark.skipif(sys.platform == "win32", reason="windows paths are a mess")
@pytest.mark.write_disk
@pytest.mark.parametrize(
    "scan_type",
    [
        (pl.DataFrame.write_parquet, pl.scan_parquet),
        (pl.DataFrame.write_ipc, pl.scan_ipc),
        (pl.DataFrame.write_csv, pl.scan_csv),
        (pl.DataFrame.write_ndjson, pl.scan_ndjson),
    ],
)
def test_async_read_21945(tmp_path: Path, scan_type: tuple[Any, Any]) -> None:
    f1 = tmp_path / "f1"
    f2 = tmp_path / "f2"

    pl.DataFrame({"value": [1, 2]}).write_parquet(f1)
    pl.DataFrame({"value": [3]}).write_parquet(f2)

    df = (
        pl.scan_parquet(["file://" + str(f1), str(f2)], include_file_paths="foo")
        .filter(value=1)
        .collect()
    )

    assert_frame_equal(
        df, pl.DataFrame({"value": [1], "foo": ["file://" + f1.as_posix()]})
    )


@pytest.mark.write_disk
@pytest.mark.parametrize("with_str_contains", [False, True])
def test_hive_pruning_str_contains_21706(
    tmp_path: Path, capfd: Any, monkeypatch: Any, with_str_contains: bool
) -> None:
    df = pl.DataFrame(
        {
            "pdate": [20250301, 20250301, 20250302, 20250302, 20250303, 20250303],
            "prod_id": ["A1", "A2", "B1", "B2", "C1", "C2"],
            "price": [11, 22, 33, 44, 55, 66],
        }
    )

    df.write_parquet(tmp_path, partition_by=["pdate"])

    monkeypatch.setenv("POLARS_VERBOSE", "1")
    f = pl.col("pdate") == 20250303
    if with_str_contains:
        f = f & pl.col("prod_id").str.contains("1")

    df = pl.scan_parquet(tmp_path, hive_partitioning=True).filter(f).collect()

    captured = capfd.readouterr().err
    assert "allows skipping 2 / 3" in captured

    assert_frame_equal(
        df,
        pl.scan_parquet(tmp_path, hive_partitioning=True).collect().filter(f),
    )


@pytest.mark.skipif(sys.platform == "win32", reason="paths not valid on Windows")
def test_scan_no_glob_special_chars_23292(tmp_path: Path) -> None:
    tmp_path.mkdir(exist_ok=True)

    path = tmp_path / "%?.parquet"
    df = pl.DataFrame({"a": 1})
    df.write_parquet(path)

    assert_frame_equal(pl.scan_parquet(f"file://{path}", glob=False).collect(), df)


@pytest.mark.write_disk
@pytest.mark.parametrize(
    ("scan_function", "failed_message", "name_in_context"),
    [
        (
            pl.scan_parquet,
            "failed to retrieve first file schema (parquet)",
            "'parquet scan'",
        ),
        (pl.scan_ipc, "failed to retrieve first file schema (ipc)", "'ipc scan'"),
        (pl.scan_csv, "failed to retrieve file schemas (csv)", "'csv scan'"),
        (
            pl.scan_ndjson,
            "failed to retrieve first file schema (ndjson)",
            "'ndjson scan'",
        ),
    ],
)
def test_scan_empty_paths_friendly_error(
    tmp_path: Path,
    scan_function: Any,
    failed_message: str,
    name_in_context: str,
) -> None:
    q = scan_function(tmp_path)

    with pytest.raises(pl.exceptions.ComputeError) as exc:
        q.collect()

    exc_str = exc.exconly()

    assert (
        f"ComputeError: {failed_message}: expanded paths were empty "
        "(path expansion input: 'paths: [Local"
    ) in exc_str

    assert "glob: true)." in exc_str
    assert exc_str.count(tmp_path.name) == 1

    assert (
        name_in_context
        in exc_str.split(
            "This error occurred with the following context stack:", maxsplit=1
        )[1]
    )

    if scan_function is pl.scan_parquet:
        assert (
            "Hint: passing a schema can allow this scan to succeed with an empty DataFrame."
            in exc_str
        )

    # Multiple input paths
    q = scan_function([tmp_path, tmp_path])

    with pytest.raises(pl.exceptions.ComputeError) as exc:
        q.collect()

    exc_str = exc.exconly()

    assert (
        f"ComputeError: {failed_message}: expanded paths were empty "
        "(path expansion input: 'paths: [Local"
    ) in exc_str

    assert "glob: true)." in exc_str

    assert exc_str.count(tmp_path.name) == 2

    q = scan_function([])

    with pytest.raises(pl.exceptions.ComputeError) as exc:
        q.collect()

    exc_str = exc.exconly()

    # There is no "path expansion resulted in" for this error message as the
    # original input sources were empty.
    assert f"ComputeError: {failed_message}: empty input: paths: []" in exc_str

    if scan_function is pl.scan_parquet:
        assert (
            "Hint: passing a schema can allow this scan to succeed with an empty DataFrame."
            in exc_str
        )

    # TODO: glob parameter not supported in some scan types
    cx = (
        pytest.raises(pl.exceptions.ComputeError, match="glob: false")
        if scan_function is pl.scan_csv or scan_function is pl.scan_parquet
        else pytest.raises(TypeError, match="unexpected keyword argument 'glob'")  # type: ignore[arg-type]
    )

    with cx:
        scan_function(tmp_path, glob=False).collect()