# GitHub repository: pola-rs/polars
# Path: py-polars/tests/unit/io/test_lazy_csv.py
from __future__ import annotations

import io
import tempfile
from collections import OrderedDict
from pathlib import Path
from typing import IO

import numpy as np
import pytest

import polars as pl
from polars.exceptions import ComputeError, ShapeError
from polars.testing import assert_frame_equal


@pytest.fixture
def foods_file_path(io_files_path: Path) -> Path:
    return io_files_path / "foods1.csv"


def test_scan_csv(io_files_path: Path) -> None:
    df = pl.scan_csv(io_files_path / "small.csv")
    assert df.collect().shape == (4, 3)


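# Concatenating more scans than there are worker threads, with common-subplan
# elimination disabled, must not deadlock the thread pool.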
def test_scan_csv_no_cse_deadlock(io_files_path: Path) -> None:
    dfs = [pl.scan_csv(io_files_path / "small.csv")] * (pl.thread_pool_size() + 1)
    pl.concat(dfs, parallel=True).collect(
        optimizations=pl.QueryOptFlags(comm_subplan_elim=False)
    )


def test_scan_empty_csv(io_files_path: Path) -> None:
    with pytest.raises(Exception) as excinfo:
        pl.scan_csv(io_files_path / "empty.csv").collect()
    assert "empty CSV" in str(excinfo.value)

    lf = pl.scan_csv(io_files_path / "empty.csv", raise_if_empty=False)
    assert_frame_equal(lf, pl.LazyFrame())


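# Eager read_csv and lazy scan_csv should decode the same non-UTF-8 bytes
# identically under the "utf8-lossy" encoding.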
@pytest.mark.write_disk
def test_invalid_utf8(tmp_path: Path) -> None:
    tmp_path.mkdir(exist_ok=True)

    np.random.seed(1)
    bts = bytes(np.random.randint(0, 255, 200))

    file_path = tmp_path / "nonutf8.csv"
    file_path.write_bytes(bts)

    a = pl.read_csv(file_path, has_header=False, encoding="utf8-lossy")
    b = pl.scan_csv(file_path, has_header=False, encoding="utf8-lossy").collect()

    assert_frame_equal(a, b)


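# The row index reflects pre-filter row positions, so filtered rows keep their
# original indices (optionally offset via with_row_index).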
def test_row_index(foods_file_path: Path) -> None:
    df = pl.read_csv(foods_file_path, row_index_name="row_index")
    assert df["row_index"].to_list() == list(range(27))

    df = (
        pl.scan_csv(foods_file_path, row_index_name="row_index")
        .filter(pl.col("category") == pl.lit("vegetables"))
        .collect()
    )

    assert df["row_index"].to_list() == [0, 6, 11, 13, 14, 20, 25]

    df = (
        pl.scan_csv(foods_file_path, row_index_name="row_index")
        .with_row_index("foo", 10)
        .filter(pl.col("category") == pl.lit("vegetables"))
        .collect()
    )

    assert df["foo"].to_list() == [10, 16, 21, 23, 24, 30, 35]


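# The schema-override tests below each run against both a single file and a
# glob pattern matching multiple files.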
@pytest.mark.parametrize("file_name", ["foods1.csv", "foods*.csv"])
@pytest.mark.may_fail_auto_streaming  # missing_columns parameter for CSV
def test_scan_csv_schema_overwrite_and_dtypes_overwrite(
    io_files_path: Path, file_name: str
) -> None:
    file_path = io_files_path / file_name
    q = pl.scan_csv(
        file_path,
        schema_overrides={"calories_foo": pl.String, "fats_g_foo": pl.Float32},
        with_column_names=lambda names: [f"{a}_foo" for a in names],
    )

    assert q.collect_schema().dtypes() == [pl.String, pl.String, pl.Float32, pl.Int64]

    df = q.collect()

    assert df.dtypes == [pl.String, pl.String, pl.Float32, pl.Int64]
    assert df.columns == [
        "category_foo",
        "calories_foo",
        "fats_g_foo",
        "sugars_g_foo",
    ]


@pytest.mark.parametrize("file_name", ["foods1.csv", "foods*.csv"])
@pytest.mark.parametrize("dtype", [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16])
@pytest.mark.may_fail_auto_streaming  # missing_columns parameter for CSV
def test_scan_csv_schema_overwrite_and_small_dtypes_overwrite(
    io_files_path: Path, file_name: str, dtype: pl.DataType
) -> None:
    file_path = io_files_path / file_name
    df = pl.scan_csv(
        file_path,
        schema_overrides={"calories_foo": pl.String, "sugars_g_foo": dtype},
        with_column_names=lambda names: [f"{a}_foo" for a in names],
    ).collect()
    assert df.dtypes == [pl.String, pl.String, pl.Float64, dtype]
    assert df.columns == [
        "category_foo",
        "calories_foo",
        "fats_g_foo",
        "sugars_g_foo",
    ]


@pytest.mark.parametrize("file_name", ["foods1.csv", "foods*.csv"])
@pytest.mark.may_fail_auto_streaming  # missing_columns parameter for CSV
def test_scan_csv_schema_new_columns_dtypes(
    io_files_path: Path, file_name: str
) -> None:
    file_path = io_files_path / file_name

    for dtype in [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16]:
        # assign 'new_columns', providing partial dtype overrides
        df1 = pl.scan_csv(
            file_path,
            schema_overrides={"calories": pl.String, "sugars": dtype},
            new_columns=["category", "calories", "fats", "sugars"],
        ).collect()
        assert df1.dtypes == [pl.String, pl.String, pl.Float64, dtype]
        assert df1.columns == ["category", "calories", "fats", "sugars"]

        # assign 'new_columns' with 'dtypes' list
        df2 = pl.scan_csv(
            file_path,
            schema_overrides=[pl.String, pl.String, pl.Float64, dtype],
            new_columns=["category", "calories", "fats", "sugars"],
        ).collect()
        assert df1.rows() == df2.rows()

    # rename existing columns, then lazy-select disjoint cols
    lf = pl.scan_csv(
        file_path,
        new_columns=["colw", "colx", "coly", "colz"],
    )
    schema = lf.collect_schema()
    assert schema.dtypes() == [pl.String, pl.Int64, pl.Float64, pl.Int64]
    assert schema.names() == ["colw", "colx", "coly", "colz"]
    assert (
        lf.select("colz", "colx").collect().rows()
        == df1.select("sugars", pl.col("calories").cast(pl.Int64)).rows()
    )

    # partially rename columns / overwrite dtypes
    df4 = pl.scan_csv(
        file_path,
        schema_overrides=[pl.String, pl.String],
        new_columns=["category", "calories"],
    ).collect()
    assert df4.dtypes == [pl.String, pl.String, pl.Float64, pl.Int64]
    assert df4.columns == ["category", "calories", "fats_g", "sugars_g"]

    # cannot have len(new_columns) > len(actual columns)
    with pytest.raises(ShapeError):
        pl.scan_csv(
            file_path,
            schema_overrides=[pl.String, pl.String],
            new_columns=["category", "calories", "c3", "c4", "c5"],
        ).collect()

    # cannot set both 'new_columns' and 'with_column_names'
    with pytest.raises(ValueError, match=r"mutually.exclusive"):
        pl.scan_csv(
            file_path,
            schema_overrides=[pl.String, pl.String],
            new_columns=["category", "calories", "fats", "sugars"],
            with_column_names=lambda cols: [col.capitalize() for col in cols],
        ).collect()


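# n_rows limits what is read at the scan, before any filtering: of the first
# four rows, only the one with index 3 survives the filter.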
def test_lazy_n_rows(foods_file_path: Path) -> None:
    df = (
        pl.scan_csv(foods_file_path, n_rows=4, row_index_name="idx")
        .filter(pl.col("idx") > 2)
        .collect()
    )
    assert df.to_dict(as_series=False) == {
        "idx": [3],
        "category": ["fruit"],
        "calories": [60],
        "fats_g": [0.0],
        "sugars_g": [11],
    }


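# Neither predicate may survive as a standalone FILTER node in the optimized
# plan, and the result must still match the reference frame below.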
def test_lazy_row_index_no_push_down(foods_file_path: Path) -> None:
    q = (
        pl.scan_csv(foods_file_path)
        .with_row_index()
        .filter(pl.col("index") > 13)
        .filter(pl.col("category") == pl.lit("vegetables"))
    )

    plan = q.explain()

    assert "FILTER" not in plan

    assert_frame_equal(
        q,
        pl.LazyFrame(
            [
                pl.Series("index", [14, 20, 25], dtype=pl.get_index_type()),
                pl.Series(
                    "category",
                    ["vegetables", "vegetables", "vegetables"],
                    dtype=pl.String,
                ),
                pl.Series("calories", [25, 25, 30], dtype=pl.Int64),
                pl.Series("fats_g", [0.0, 0.0, 0.0], dtype=pl.Float64),
                pl.Series("sugars_g", [4, 3, 5], dtype=pl.Int64),
            ]
        ),
    )


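# skip_rows is applied per file when scanning a glob, so each file's two
# metadata lines are skipped before the header is read.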
@pytest.mark.write_disk
def test_glob_skip_rows(tmp_path: Path) -> None:
    tmp_path.mkdir(exist_ok=True)

    for i in range(2):
        file_path = tmp_path / f"test_{i}.csv"
        file_path.write_text(
            f"""
metadata goes here
file number {i}
foo,bar,baz
1,2,3
4,5,6
7,8,9
"""
        )
    file_path = tmp_path / "*.csv"
    assert pl.read_csv(file_path, skip_rows=2).to_dict(as_series=False) == {
        "foo": [1, 4, 7, 1, 4, 7],
        "bar": [2, 5, 8, 2, 5, 8],
        "baz": [3, 6, 9, 3, 6, 9],
    }


def test_glob_n_rows(io_files_path: Path) -> None:
    file_path = io_files_path / "foods*.csv"
    df = pl.scan_csv(file_path, n_rows=40).collect()

    # 27 rows from foods1.csv and 13 from foods2.csv
    assert df.shape == (40, 4)

    # take first and last rows
    assert df[[0, 39]].to_dict(as_series=False) == {
        "category": ["vegetables", "seafood"],
        "calories": [45, 146],
        "fats_g": [0.5, 6.0],
        "sugars_g": [2, 2],
    }


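# Regression test for #8483: dtype overrides on columns that end up not being
# projected (only pl.len() is selected here) must not raise.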
def test_scan_csv_schema_overwrite_not_projected_8483(foods_file_path: Path) -> None:
    df = (
        pl.scan_csv(
            foods_file_path,
            schema_overrides={"calories": pl.String, "sugars_g": pl.Int8},
        )
        .select(pl.len())
        .collect()
    )
    expected = pl.DataFrame({"len": 27}, schema={"len": pl.get_index_type()})
    assert_frame_equal(df, expected)


def test_csv_list_arg(io_files_path: Path) -> None:
    first = io_files_path / "foods1.csv"
    second = io_files_path / "foods2.csv"

    df = pl.scan_csv(source=[first, second]).collect()
    assert df.shape == (54, 4)
    assert df.row(-1) == ("seafood", 194, 12.0, 1)
    assert df.row(0) == ("vegetables", 45, 0.5, 2)


# https://github.com/pola-rs/polars/issues/9887
def test_scan_csv_slice_offset_zero(io_files_path: Path) -> None:
    lf = pl.scan_csv(io_files_path / "small.csv")
    result = lf.slice(0)
    assert result.collect().height == 4


@pytest.mark.write_disk
def test_scan_empty_csv_with_row_index(tmp_path: Path) -> None:
    tmp_path.mkdir(exist_ok=True)
    file_path = tmp_path / "small.csv"
    df = pl.DataFrame({"a": []})
    df.write_csv(file_path)

    read = pl.scan_csv(file_path).with_row_index("idx")
    assert read.collect().schema == OrderedDict([("idx", pl.UInt32), ("a", pl.String)])


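# Regression test for #15515: per-column null_values must survive projection
# pushdown.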
@pytest.mark.write_disk
def test_csv_null_values_with_projection_15515() -> None:
    data = """IndCode,SireCode,BirthDate,Flag
ID00316,.,19940315,
"""

    with tempfile.NamedTemporaryFile() as f:
        f.write(data.encode())
        f.seek(0)

        q = (
            pl.scan_csv(f.name, null_values={"SireCode": "."})
            .with_columns(pl.col("SireCode").alias("SireKey"))
            .select("SireKey", "BirthDate")
        )

        assert q.collect().to_dict(as_series=False) == {
            "SireKey": [None],
            "BirthDate": [19940315],
        }


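# Regression test for #15254: a user-supplied schema wider than the header is
# respected; rows lacking the extra fields read as null.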
@pytest.mark.write_disk
def test_csv_respect_user_schema_ragged_lines_15254() -> None:
    with tempfile.NamedTemporaryFile() as f:
        f.write(
            b"""
A,B,C
1,2,3
4,5,6,7,8
9,10,11
""".strip()
        )
        f.seek(0)

        df = pl.scan_csv(
            f.name, schema=dict.fromkeys("ABCDE", pl.String), truncate_ragged_lines=True
        ).collect()
        assert df.to_dict(as_series=False) == {
            "A": ["1", "4", "9"],
            "B": ["2", "5", "10"],
            "C": ["3", "6", "11"],
            "D": [None, "7", None],
            "E": [None, "8", None],
        }


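# Scanning a list of files with mismatched schemas must raise; supplying an
# explicit single-column schema makes equal-width files readable.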
@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize(
    "dfs",
    [
        [pl.DataFrame({"a": [1, 2, 3]}), pl.DataFrame({"b": [4, 5, 6]})],
        [
            pl.DataFrame({"a": [1, 2, 3]}),
            pl.DataFrame({"b": [4, 5, 6], "c": [7, 8, 9]}),
        ],
    ],
)
@pytest.mark.may_fail_auto_streaming  # missing_columns parameter for CSV
def test_file_list_schema_mismatch(
    tmp_path: Path, dfs: list[pl.DataFrame], streaming: bool
) -> None:
    tmp_path.mkdir(exist_ok=True)

    paths = [f"{tmp_path}/{i}.csv" for i in range(len(dfs))]

    for df, path in zip(dfs, paths, strict=True):
        df.write_csv(path)

    lf = pl.scan_csv(paths)
    with pytest.raises((ComputeError, pl.exceptions.ColumnNotFoundError)):
        lf.collect(engine="streaming" if streaming else "in-memory")

    if streaming:
        pytest.xfail(reason="missing_columns parameter for CSV")

    if len({df.width for df in dfs}) == 1:
        expect = pl.concat(df.select(x=pl.first().cast(pl.Int8)) for df in dfs)
        out = pl.scan_csv(paths, schema={"x": pl.Int8}).collect(  # type: ignore[call-overload]
            engine="streaming" if streaming else "in-memory"  # type: ignore[redundant-expr]
        )

        assert_frame_equal(out, expect)


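# A column that parses as integers in one file and strings in another should be
# read under the common supertype, String.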
@pytest.mark.may_fail_auto_streaming
@pytest.mark.parametrize("streaming", [True, False])
def test_file_list_schema_supertype(tmp_path: Path, streaming: bool) -> None:
    tmp_path.mkdir(exist_ok=True)

    data_lst = [
        """\
a
1
2
""",
        """\
a
b
c
""",
    ]

    paths = [f"{tmp_path}/{i}.csv" for i in range(len(data_lst))]

    for data, path in zip(data_lst, paths, strict=True):
        with Path(path).open("w") as f:
            f.write(data)

    expect = pl.Series("a", ["1", "2", "b", "c"]).to_frame()
    out = pl.scan_csv(paths).collect(engine="streaming" if streaming else "in-memory")

    assert_frame_equal(out, expect)


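# Regression test for #16327: a leading comment line is skipped per file when
# scanning a file list, not just in the first file.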
@pytest.mark.parametrize("streaming", [True, False])
def test_file_list_comment_skip_rows_16327(tmp_path: Path, streaming: bool) -> None:
    tmp_path.mkdir(exist_ok=True)

    data_lst = [
        """\
# comment
a
b
c
""",
        """\
a
b
c
""",
    ]

    paths = [f"{tmp_path}/{i}.csv" for i in range(len(data_lst))]

    for data, path in zip(data_lst, paths, strict=True):
        with Path(path).open("w") as f:
            f.write(data)

    expect = pl.Series("a", ["b", "c", "b", "c"]).to_frame()
    out = pl.scan_csv(paths, comment_prefix="#").collect(
        engine="streaming" if streaming else "in-memory"
    )

    assert_frame_equal(out, expect)


@pytest.mark.xfail(reason="Bug: https://github.com/pola-rs/polars/issues/17634")
def test_scan_csv_with_column_names_nonexistent_file() -> None:
    path_str = "my-nonexistent-data.csv"
    path = Path(path_str)
    assert not path.exists()

    # Just calling the scan function should not raise any errors
    result = pl.scan_csv(path, with_column_names=lambda x: [c.upper() for c in x])
    assert isinstance(result, pl.LazyFrame)

    # Upon collection, it should fail
    with pytest.raises(FileNotFoundError):
        result.collect()


def test_select_nonexistent_column() -> None:
    csv = "a\n1"
    f = io.StringIO(csv)

    with pytest.raises(pl.exceptions.ColumnNotFoundError):
        pl.scan_csv(f).select("b").collect()


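# Regression test for #22531: a provided schema with more fields than the file
# contains yields null-filled columns for the extras.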
def test_scan_csv_provided_schema_with_extra_fields_22531() -> None:
    data = b"""\
a,b,c
a,b,c
"""

    schema = dict.fromkeys(["a", "b", "c", "d", "e"], pl.String)

    assert_frame_equal(
        pl.scan_csv(data, schema=schema).collect(),
        pl.DataFrame(
            {
                "a": "a",
                "b": "b",
                "c": "c",
                "d": None,
                "e": None,
            },
            schema=schema,
        ),
    )


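# Regression test for #22996: a negative slice (tail) combined with
# comment_prefix must still return every row from both sources.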
def test_csv_negative_slice_comment_char_22996() -> None:
    f = b"""\
a,b
1,1
"""

    q = pl.scan_csv(2 * [f], comment_prefix="#").tail(100)
    assert_frame_equal(q.collect(), pl.DataFrame({"a": [1, 1], "b": [1, 1]}))


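# Regression test for #23629: multi-byte UTF-8 content must round-trip through
# both io.BytesIO and io.StringIO, including values longer than one chunk.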
def test_csv_io_object_utf8_23629() -> None:
    n_repeats = 10_000
    for df in [
        pl.DataFrame({"a": ["é,è"], "b": ["c,d"]}),
        pl.DataFrame({"a": ["Ú;и"], "b": ["c;d"]}),
        pl.DataFrame({"a": ["a,b"], "b": ["c,d"]}),
        pl.DataFrame({"a": ["é," * n_repeats + "è"], "b": ["c," * n_repeats + "d"]}),
    ]:
        # bytes
        f_bytes = io.BytesIO()
        df.write_csv(f_bytes)
        f_bytes.seek(0)
        df_bytes = pl.read_csv(f_bytes)
        assert_frame_equal(df, df_bytes)

        # str
        f_str = io.StringIO()
        df.write_csv(f_str)
        f_str.seek(0)
        df_str = pl.read_csv(f_str)
        assert_frame_equal(df, df_str)


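# Regression test for #26127: skip_rows combined with n_rows must work across
# multiple in-memory files; after skipping two rows, "4,5,6" becomes the header
# and no data rows remain.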
def test_scan_csv_multiple_files_skip_rows_overflow_26127() -> None:
    files: list[IO[bytes]] = [
        io.BytesIO(b"foo,bar,baz\n1,2,3\n4,5,6") for _ in range(2)
    ]
    assert_frame_equal(
        pl.scan_csv(
            files,
            n_rows=4,
            skip_rows=2,
        ).collect(),
        pl.DataFrame(schema={"4": pl.String, "5": pl.String, "6": pl.String}),
    )