"""Tests for lazy CSV scanning (``pl.scan_csv``).

Source: pola-rs/polars, py-polars/tests/unit/io/test_lazy_csv.py
"""

from __future__ import annotations

import io
import tempfile
from collections import OrderedDict
from pathlib import Path

import numpy as np
import pytest

import polars as pl
from polars.exceptions import ComputeError, ShapeError
from polars.testing import assert_frame_equal


@pytest.fixture
def foods_file_path(io_files_path: Path) -> Path:
    return io_files_path / "foods1.csv"


def test_scan_csv(io_files_path: Path) -> None:
    df = pl.scan_csv(io_files_path / "small.csv")
    assert df.collect().shape == (4, 3)


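# Oversubscribe the thread pool (one more scan than there are threads) with
# common-subplan elimination disabled; this is a regression guard against a
# collect() deadlock under that configuration.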
def test_scan_csv_no_cse_deadlock(io_files_path: Path) -> None:
    dfs = [pl.scan_csv(io_files_path / "small.csv")] * (pl.thread_pool_size() + 1)
    pl.concat(dfs, parallel=True).collect(
        optimizations=pl.QueryOptFlags(comm_subplan_elim=False)
    )


def test_scan_empty_csv(io_files_path: Path) -> None:
    with pytest.raises(Exception) as excinfo:
        pl.scan_csv(io_files_path / "empty.csv").collect()
    assert "empty CSV" in str(excinfo.value)

    lf = pl.scan_csv(io_files_path / "empty.csv", raise_if_empty=False)
    assert_frame_equal(lf, pl.LazyFrame())


@pytest.mark.write_disk
def test_invalid_utf8(tmp_path: Path) -> None:
    tmp_path.mkdir(exist_ok=True)

    np.random.seed(1)
    bts = bytes(np.random.randint(0, 255, 200))

    file_path = tmp_path / "nonutf8.csv"
    file_path.write_bytes(bts)

    a = pl.read_csv(file_path, has_header=False, encoding="utf8-lossy")
    b = pl.scan_csv(file_path, has_header=False, encoding="utf8-lossy").collect()

    assert_frame_equal(a, b)


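# Row indices are assigned when the rows are read, before any filter runs,
# so rows surviving a filter keep their original (non-contiguous) indices.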
def test_row_index(foods_file_path: Path) -> None:
    df = pl.read_csv(foods_file_path, row_index_name="row_index")
    assert df["row_index"].to_list() == list(range(27))

    df = (
        pl.scan_csv(foods_file_path, row_index_name="row_index")
        .filter(pl.col("category") == pl.lit("vegetables"))
        .collect()
    )

    assert df["row_index"].to_list() == [0, 6, 11, 13, 14, 20, 25]

    df = (
        pl.scan_csv(foods_file_path, row_index_name="row_index")
        .with_row_index("foo", 10)
        .filter(pl.col("category") == pl.lit("vegetables"))
        .collect()
    )

    assert df["foo"].to_list() == [10, 16, 21, 23, 24, 30, 35]


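# Note: the keys of `schema_overrides` refer to the names produced by
# `with_column_names`, i.e. the renamed "*_foo" columns.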
@pytest.mark.parametrize("file_name", ["foods1.csv", "foods*.csv"])
81
@pytest.mark.may_fail_auto_streaming # missing_columns parameter for CSV
82
def test_scan_csv_schema_overwrite_and_dtypes_overwrite(
83
io_files_path: Path, file_name: str
84
) -> None:
85
file_path = io_files_path / file_name
86
q = pl.scan_csv(
87
file_path,
88
schema_overrides={"calories_foo": pl.String, "fats_g_foo": pl.Float32},
89
with_column_names=lambda names: [f"{a}_foo" for a in names],
90
)
91
92
assert q.collect_schema().dtypes() == [pl.String, pl.String, pl.Float32, pl.Int64]
93
94
df = q.collect()
95
96
assert df.dtypes == [pl.String, pl.String, pl.Float32, pl.Int64]
97
assert df.columns == [
98
"category_foo",
99
"calories_foo",
100
"fats_g_foo",
101
"sugars_g_foo",
102
]
103
104
105
@pytest.mark.parametrize("file_name", ["foods1.csv", "foods*.csv"])
106
@pytest.mark.parametrize("dtype", [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16])
107
@pytest.mark.may_fail_auto_streaming # missing_columns parameter for CSV
108
def test_scan_csv_schema_overwrite_and_small_dtypes_overwrite(
109
io_files_path: Path, file_name: str, dtype: pl.DataType
110
) -> None:
111
file_path = io_files_path / file_name
112
df = pl.scan_csv(
113
file_path,
114
schema_overrides={"calories_foo": pl.String, "sugars_g_foo": dtype},
115
with_column_names=lambda names: [f"{a}_foo" for a in names],
116
).collect()
117
assert df.dtypes == [pl.String, pl.String, pl.Float64, dtype]
118
assert df.columns == [
119
"category_foo",
120
"calories_foo",
121
"fats_g_foo",
122
"sugars_g_foo",
123
]
124
125
126
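# `schema_overrides` accepts either a dict keyed by column name or a
# positional list of dtypes; both forms are exercised below.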
@pytest.mark.parametrize("file_name", ["foods1.csv", "foods*.csv"])
127
@pytest.mark.may_fail_auto_streaming # missing_columns parameter for CSV
128
def test_scan_csv_schema_new_columns_dtypes(
129
io_files_path: Path, file_name: str
130
) -> None:
131
file_path = io_files_path / file_name
132
133
for dtype in [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16]:
134
# assign 'new_columns', providing partial dtype overrides
135
df1 = pl.scan_csv(
136
file_path,
137
schema_overrides={"calories": pl.String, "sugars": dtype},
138
new_columns=["category", "calories", "fats", "sugars"],
139
).collect()
140
assert df1.dtypes == [pl.String, pl.String, pl.Float64, dtype]
141
assert df1.columns == ["category", "calories", "fats", "sugars"]
142
143
# assign 'new_columns' with 'dtypes' list
144
df2 = pl.scan_csv(
145
file_path,
146
schema_overrides=[pl.String, pl.String, pl.Float64, dtype],
147
new_columns=["category", "calories", "fats", "sugars"],
148
).collect()
149
assert df1.rows() == df2.rows()
150
151
# rename existing columns, then lazy-select disjoint cols
152
lf = pl.scan_csv(
153
file_path,
154
new_columns=["colw", "colx", "coly", "colz"],
155
)
156
schema = lf.collect_schema()
157
assert schema.dtypes() == [pl.String, pl.Int64, pl.Float64, pl.Int64]
158
assert schema.names() == ["colw", "colx", "coly", "colz"]
159
assert (
160
lf.select("colz", "colx").collect().rows()
161
== df1.select("sugars", pl.col("calories").cast(pl.Int64)).rows()
162
)
163
164
# partially rename columns / overwrite dtypes
165
df4 = pl.scan_csv(
166
file_path,
167
schema_overrides=[pl.String, pl.String],
168
new_columns=["category", "calories"],
169
).collect()
170
assert df4.dtypes == [pl.String, pl.String, pl.Float64, pl.Int64]
171
assert df4.columns == ["category", "calories", "fats_g", "sugars_g"]
172
173
# cannot have len(new_columns) > len(actual columns)
174
with pytest.raises(ShapeError):
175
pl.scan_csv(
176
file_path,
177
schema_overrides=[pl.String, pl.String],
178
new_columns=["category", "calories", "c3", "c4", "c5"],
179
).collect()
180
181
# cannot set both 'new_columns' and 'with_column_names'
182
with pytest.raises(ValueError, match="mutually.exclusive"):
183
pl.scan_csv(
184
file_path,
185
schema_overrides=[pl.String, pl.String],
186
new_columns=["category", "calories", "fats", "sugars"],
187
with_column_names=lambda cols: [col.capitalize() for col in cols],
188
).collect()
189
190
191
def test_lazy_n_rows(foods_file_path: Path) -> None:
    df = (
        pl.scan_csv(foods_file_path, n_rows=4, row_index_name="idx")
        .filter(pl.col("idx") > 2)
        .collect()
    )
    assert df.to_dict(as_series=False) == {
        "idx": [3],
        "category": ["fruit"],
        "calories": [60],
        "fats_g": [0.0],
        "sugars_g": [11],
    }


def test_lazy_row_index_no_push_down(foods_file_path: Path) -> None:
    plan = (
        pl.scan_csv(foods_file_path)
        .with_row_index()
        .filter(pl.col("index") == 1)
        .filter(pl.col("category") == pl.lit("vegetables"))
        .explain(optimizations=pl.QueryOptFlags(predicate_pushdown=True))
    )
    # the predicate on the row index cannot be pushed past index creation...
    assert 'FILTER [(col("index")) == (1)]\nFROM' in plan
    # ...but the predicate unrelated to the row index is pushed into the scan.
    assert 'SELECTION: [(col("category")) == ("vegetables")]' in plan


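# `skip_rows` is applied to every file matched by the glob. The expected
# output implies the parser ignores the leading blank line in each file, so
# skipping two rows drops both metadata lines and leaves "foo,bar,baz" as
# the header.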
@pytest.mark.write_disk
def test_glob_skip_rows(tmp_path: Path) -> None:
    tmp_path.mkdir(exist_ok=True)

    for i in range(2):
        file_path = tmp_path / f"test_{i}.csv"
        file_path.write_text(
            f"""
metadata goes here
file number {i}
foo,bar,baz
1,2,3
4,5,6
7,8,9
"""
        )
    file_path = tmp_path / "*.csv"
    assert pl.read_csv(file_path, skip_rows=2).to_dict(as_series=False) == {
        "foo": [1, 4, 7, 1, 4, 7],
        "bar": [2, 5, 8, 2, 5, 8],
        "baz": [3, 6, 9, 3, 6, 9],
    }


def test_glob_n_rows(io_files_path: Path) -> None:
    file_path = io_files_path / "foods*.csv"
    df = pl.scan_csv(file_path, n_rows=40).collect()

    # 27 rows from foods1.csv and 13 from foods2.csv
    assert df.shape == (40, 4)

    # take first and last rows
    assert df[[0, 39]].to_dict(as_series=False) == {
        "category": ["vegetables", "seafood"],
        "calories": [45, 146],
        "fats_g": [0.5, 6.0],
        "sugars_g": [2, 2],
    }


def test_scan_csv_schema_overwrite_not_projected_8483(foods_file_path: Path) -> None:
    df = (
        pl.scan_csv(
            foods_file_path,
            schema_overrides={"calories": pl.String, "sugars_g": pl.Int8},
        )
        .select(pl.len())
        .collect()
    )
    expected = pl.DataFrame({"len": 27}, schema={"len": pl.UInt32})
    assert_frame_equal(df, expected)


def test_csv_list_arg(io_files_path: Path) -> None:
    first = io_files_path / "foods1.csv"
    second = io_files_path / "foods2.csv"

    df = pl.scan_csv(source=[first, second]).collect()
    assert df.shape == (54, 4)
    assert df.row(-1) == ("seafood", 194, 12.0, 1)
    assert df.row(0) == ("vegetables", 45, 0.5, 2)


# https://github.com/pola-rs/polars/issues/9887
def test_scan_csv_slice_offset_zero(io_files_path: Path) -> None:
    lf = pl.scan_csv(io_files_path / "small.csv")
    result = lf.slice(0)
    assert result.collect().height == 4


@pytest.mark.write_disk
def test_scan_empty_csv_with_row_index(tmp_path: Path) -> None:
    tmp_path.mkdir(exist_ok=True)
    file_path = tmp_path / "small.csv"
    df = pl.DataFrame({"a": []})
    df.write_csv(file_path)

    read = pl.scan_csv(file_path).with_row_index("idx")
    assert read.collect().schema == OrderedDict([("idx", pl.UInt32), ("a", pl.String)])


@pytest.mark.write_disk
def test_csv_null_values_with_projection_15515() -> None:
    data = """IndCode,SireCode,BirthDate,Flag
ID00316,.,19940315,
"""

    with tempfile.NamedTemporaryFile() as f:
        f.write(data.encode())
        f.seek(0)

        q = (
            pl.scan_csv(f.name, null_values={"SireCode": "."})
            .with_columns(pl.col("SireCode").alias("SireKey"))
            .select("SireKey", "BirthDate")
        )

        assert q.collect().to_dict(as_series=False) == {
            "SireKey": [None],
            "BirthDate": [19940315],
        }


@pytest.mark.write_disk
def test_csv_respect_user_schema_ragged_lines_15254() -> None:
    with tempfile.NamedTemporaryFile() as f:
        f.write(
            b"""
A,B,C
1,2,3
4,5,6,7,8
9,10,11
""".strip()
        )
        f.seek(0)

        df = pl.scan_csv(
            f.name, schema=dict.fromkeys("ABCDE", pl.String), truncate_ragged_lines=True
        ).collect()
        assert df.to_dict(as_series=False) == {
            "A": ["1", "4", "9"],
            "B": ["2", "5", "10"],
            "C": ["3", "6", "11"],
            "D": [None, "7", None],
            "E": [None, "8", None],
        }


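# Files in a multi-file scan must agree on column names: a mismatch should
# raise instead of being silently reconciled.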
@pytest.mark.parametrize("streaming", [True, False])
349
@pytest.mark.parametrize(
350
"dfs",
351
[
352
[pl.DataFrame({"a": [1, 2, 3]}), pl.DataFrame({"b": [4, 5, 6]})],
353
[
354
pl.DataFrame({"a": [1, 2, 3]}),
355
pl.DataFrame({"b": [4, 5, 6], "c": [7, 8, 9]}),
356
],
357
],
358
)
359
@pytest.mark.may_fail_auto_streaming # missing_columns parameter for CSV
360
def test_file_list_schema_mismatch(
361
tmp_path: Path, dfs: list[pl.DataFrame], streaming: bool
362
) -> None:
363
tmp_path.mkdir(exist_ok=True)
364
365
paths = [f"{tmp_path}/{i}.csv" for i in range(len(dfs))]
366
367
for df, path in zip(dfs, paths):
368
df.write_csv(path)
369
370
lf = pl.scan_csv(paths)
371
with pytest.raises((ComputeError, pl.exceptions.ColumnNotFoundError)):
372
lf.collect(engine="streaming" if streaming else "in-memory")
373
374
if streaming:
375
pytest.xfail(reason="missing_columns parameter for CSV")
376
377
if len({df.width for df in dfs}) == 1:
378
expect = pl.concat(df.select(x=pl.first().cast(pl.Int8)) for df in dfs)
379
out = pl.scan_csv(paths, schema={"x": pl.Int8}).collect( # type: ignore[call-overload]
380
engine="streaming" if streaming else "in-memory" # type: ignore[redundant-expr]
381
)
382
383
assert_frame_equal(out, expect)
384
385
386
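# When the same column parses to different dtypes across files (Int64 vs
# String below), the scan resolves them to their common supertype, String.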
@pytest.mark.may_fail_auto_streaming
@pytest.mark.parametrize("streaming", [True, False])
def test_file_list_schema_supertype(tmp_path: Path, streaming: bool) -> None:
    tmp_path.mkdir(exist_ok=True)

    data_lst = [
        """\
a
1
2
""",
        """\
a
b
c
""",
    ]

    paths = [f"{tmp_path}/{i}.csv" for i in range(len(data_lst))]

    for data, path in zip(data_lst, paths):
        with Path(path).open("w") as f:
            f.write(data)

    expect = pl.Series("a", ["1", "2", "b", "c"]).to_frame()
    out = pl.scan_csv(paths).collect(engine="streaming" if streaming else "in-memory")

    assert_frame_equal(out, expect)


@pytest.mark.parametrize("streaming", [True, False])
def test_file_list_comment_skip_rows_16327(tmp_path: Path, streaming: bool) -> None:
    tmp_path.mkdir(exist_ok=True)

    data_lst = [
        """\
# comment
a
b
c
""",
        """\
a
b
c
""",
    ]

    paths = [f"{tmp_path}/{i}.csv" for i in range(len(data_lst))]

    for data, path in zip(data_lst, paths):
        with Path(path).open("w") as f:
            f.write(data)

    expect = pl.Series("a", ["b", "c", "b", "c"]).to_frame()
    out = pl.scan_csv(paths, comment_prefix="#").collect(
        engine="streaming" if streaming else "in-memory"
    )

    assert_frame_equal(out, expect)


@pytest.mark.xfail(reason="Bug: https://github.com/pola-rs/polars/issues/17634")
def test_scan_csv_with_column_names_nonexistent_file() -> None:
    path_str = "my-nonexistent-data.csv"
    path = Path(path_str)
    assert not path.exists()

    # Just calling the scan function should not raise any errors
    result = pl.scan_csv(path, with_column_names=lambda x: [c.upper() for c in x])
    assert isinstance(result, pl.LazyFrame)

    # Upon collection, it should fail
    with pytest.raises(FileNotFoundError):
        result.collect()


def test_select_nonexistent_column() -> None:
    csv = "a\n1"
    f = io.StringIO(csv)

    with pytest.raises(pl.exceptions.ColumnNotFoundError):
        pl.scan_csv(f).select("b").collect()


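# A user-provided schema may declare more columns than the file contains;
# the missing trailing columns come back as all-null.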
def test_scan_csv_provided_schema_with_extra_fields_22531() -> None:
    data = b"""\
a,b,c
a,b,c
"""

    schema = {x: pl.String for x in ["a", "b", "c", "d", "e"]}

    assert_frame_equal(
        pl.scan_csv(data, schema=schema).collect(),
        pl.DataFrame(
            {
                "a": "a",
                "b": "b",
                "c": "c",
                "d": None,
                "e": None,
            },
            schema=schema,
        ),
    )


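# Regression test for issue 22996: `tail()` amounts to a slice with a
# negative offset, which must still work when a comment prefix is set.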
def test_csv_negative_slice_comment_char_22996() -> None:
    f = b"""\
a,b
1,1
"""

    q = pl.scan_csv(2 * [f], comment_prefix="#").tail(100)
    assert_frame_equal(q.collect(), pl.DataFrame({"a": [1, 1], "b": [1, 1]}))


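# Round-trip non-ASCII CSV data through file-like objects. The `n_repeats`
# case makes the payload large, presumably so multi-byte characters straddle
# internal read-buffer boundaries.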
def test_csv_io_object_utf8_23629() -> None:
    n_repeats = 10_000
    for df in [
        pl.DataFrame({"a": ["é,è"], "b": ["c,d"]}),
        pl.DataFrame({"a": ["Ú;и"], "b": ["c;d"]}),
        pl.DataFrame({"a": ["a,b"], "b": ["c,d"]}),
        pl.DataFrame({"a": ["é," * n_repeats + "è"], "b": ["c," * n_repeats + "d"]}),
    ]:
        # bytes
        f_bytes = io.BytesIO()
        df.write_csv(f_bytes)
        f_bytes.seek(0)
        df_bytes = pl.read_csv(f_bytes)
        assert_frame_equal(df, df_bytes)

        # str
        f_str = io.StringIO()
        df.write_csv(f_str)
        f_str.seek(0)
        df_str = pl.read_csv(f_str)
        assert_frame_equal(df, df_str)