# Path: blob/main/py-polars/tests/unit/io/test_lazy_csv.py
from __future__ import annotations12import io3import tempfile4from collections import OrderedDict5from pathlib import Path67import numpy as np8import pytest910import polars as pl11from polars.exceptions import ComputeError, ShapeError12from polars.testing import assert_frame_equal131415@pytest.fixture16def foods_file_path(io_files_path: Path) -> Path:17return io_files_path / "foods1.csv"181920def test_scan_csv(io_files_path: Path) -> None:21df = pl.scan_csv(io_files_path / "small.csv")22assert df.collect().shape == (4, 3)232425def test_scan_csv_no_cse_deadlock(io_files_path: Path) -> None:26dfs = [pl.scan_csv(io_files_path / "small.csv")] * (pl.thread_pool_size() + 1)27pl.concat(dfs, parallel=True).collect(28optimizations=pl.QueryOptFlags(comm_subplan_elim=False)29)303132def test_scan_empty_csv(io_files_path: Path) -> None:33with pytest.raises(Exception) as excinfo:34pl.scan_csv(io_files_path / "empty.csv").collect()35assert "empty CSV" in str(excinfo.value)3637lf = pl.scan_csv(io_files_path / "empty.csv", raise_if_empty=False)38assert_frame_equal(lf, pl.LazyFrame())394041@pytest.mark.write_disk42def test_invalid_utf8(tmp_path: Path) -> None:43tmp_path.mkdir(exist_ok=True)4445np.random.seed(1)46bts = bytes(np.random.randint(0, 255, 200))4748file_path = tmp_path / "nonutf8.csv"49file_path.write_bytes(bts)5051a = pl.read_csv(file_path, has_header=False, encoding="utf8-lossy")52b = pl.scan_csv(file_path, has_header=False, encoding="utf8-lossy").collect()5354assert_frame_equal(a, b)555657def test_row_index(foods_file_path: Path) -> None:58df = pl.read_csv(foods_file_path, row_index_name="row_index")59assert df["row_index"].to_list() == list(range(27))6061df = (62pl.scan_csv(foods_file_path, row_index_name="row_index")63.filter(pl.col("category") == pl.lit("vegetables"))64.collect()65)6667assert df["row_index"].to_list() == [0, 6, 11, 13, 14, 20, 25]6869df = (70pl.scan_csv(foods_file_path, row_index_name="row_index")71.with_row_index("foo", 
10)72.filter(pl.col("category") == pl.lit("vegetables"))73.collect()74)7576assert df["foo"].to_list() == [10, 16, 21, 23, 24, 30, 35]777879@pytest.mark.parametrize("file_name", ["foods1.csv", "foods*.csv"])80@pytest.mark.may_fail_auto_streaming # missing_columns parameter for CSV81def test_scan_csv_schema_overwrite_and_dtypes_overwrite(82io_files_path: Path, file_name: str83) -> None:84file_path = io_files_path / file_name85q = pl.scan_csv(86file_path,87schema_overrides={"calories_foo": pl.String, "fats_g_foo": pl.Float32},88with_column_names=lambda names: [f"{a}_foo" for a in names],89)9091assert q.collect_schema().dtypes() == [pl.String, pl.String, pl.Float32, pl.Int64]9293df = q.collect()9495assert df.dtypes == [pl.String, pl.String, pl.Float32, pl.Int64]96assert df.columns == [97"category_foo",98"calories_foo",99"fats_g_foo",100"sugars_g_foo",101]102103104@pytest.mark.parametrize("file_name", ["foods1.csv", "foods*.csv"])105@pytest.mark.parametrize("dtype", [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16])106@pytest.mark.may_fail_auto_streaming # missing_columns parameter for CSV107def test_scan_csv_schema_overwrite_and_small_dtypes_overwrite(108io_files_path: Path, file_name: str, dtype: pl.DataType109) -> None:110file_path = io_files_path / file_name111df = pl.scan_csv(112file_path,113schema_overrides={"calories_foo": pl.String, "sugars_g_foo": dtype},114with_column_names=lambda names: [f"{a}_foo" for a in names],115).collect()116assert df.dtypes == [pl.String, pl.String, pl.Float64, dtype]117assert df.columns == [118"category_foo",119"calories_foo",120"fats_g_foo",121"sugars_g_foo",122]123124125@pytest.mark.parametrize("file_name", ["foods1.csv", "foods*.csv"])126@pytest.mark.may_fail_auto_streaming # missing_columns parameter for CSV127def test_scan_csv_schema_new_columns_dtypes(128io_files_path: Path, file_name: str129) -> None:130file_path = io_files_path / file_name131132for dtype in [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16]:133# assign 'new_columns', providing 
partial dtype overrides134df1 = pl.scan_csv(135file_path,136schema_overrides={"calories": pl.String, "sugars": dtype},137new_columns=["category", "calories", "fats", "sugars"],138).collect()139assert df1.dtypes == [pl.String, pl.String, pl.Float64, dtype]140assert df1.columns == ["category", "calories", "fats", "sugars"]141142# assign 'new_columns' with 'dtypes' list143df2 = pl.scan_csv(144file_path,145schema_overrides=[pl.String, pl.String, pl.Float64, dtype],146new_columns=["category", "calories", "fats", "sugars"],147).collect()148assert df1.rows() == df2.rows()149150# rename existing columns, then lazy-select disjoint cols151lf = pl.scan_csv(152file_path,153new_columns=["colw", "colx", "coly", "colz"],154)155schema = lf.collect_schema()156assert schema.dtypes() == [pl.String, pl.Int64, pl.Float64, pl.Int64]157assert schema.names() == ["colw", "colx", "coly", "colz"]158assert (159lf.select("colz", "colx").collect().rows()160== df1.select("sugars", pl.col("calories").cast(pl.Int64)).rows()161)162163# partially rename columns / overwrite dtypes164df4 = pl.scan_csv(165file_path,166schema_overrides=[pl.String, pl.String],167new_columns=["category", "calories"],168).collect()169assert df4.dtypes == [pl.String, pl.String, pl.Float64, pl.Int64]170assert df4.columns == ["category", "calories", "fats_g", "sugars_g"]171172# cannot have len(new_columns) > len(actual columns)173with pytest.raises(ShapeError):174pl.scan_csv(175file_path,176schema_overrides=[pl.String, pl.String],177new_columns=["category", "calories", "c3", "c4", "c5"],178).collect()179180# cannot set both 'new_columns' and 'with_column_names'181with pytest.raises(ValueError, match="mutually.exclusive"):182pl.scan_csv(183file_path,184schema_overrides=[pl.String, pl.String],185new_columns=["category", "calories", "fats", "sugars"],186with_column_names=lambda cols: [col.capitalize() for col in cols],187).collect()188189190def test_lazy_n_rows(foods_file_path: Path) -> None:191df = 
(192pl.scan_csv(foods_file_path, n_rows=4, row_index_name="idx")193.filter(pl.col("idx") > 2)194.collect()195)196assert df.to_dict(as_series=False) == {197"idx": [3],198"category": ["fruit"],199"calories": [60],200"fats_g": [0.0],201"sugars_g": [11],202}203204205def test_lazy_row_index_no_push_down(foods_file_path: Path) -> None:206plan = (207pl.scan_csv(foods_file_path)208.with_row_index()209.filter(pl.col("index") == 1)210.filter(pl.col("category") == pl.lit("vegetables"))211.explain(optimizations=pl.QueryOptFlags(predicate_pushdown=True))212)213# related to row count is not pushed.214assert 'FILTER [(col("index")) == (1)]\nFROM' in plan215# unrelated to row count is pushed.216assert 'SELECTION: [(col("category")) == ("vegetables")]' in plan217218219@pytest.mark.write_disk220def test_glob_skip_rows(tmp_path: Path) -> None:221tmp_path.mkdir(exist_ok=True)222223for i in range(2):224file_path = tmp_path / f"test_{i}.csv"225file_path.write_text(226f"""227metadata goes here228file number {i}229foo,bar,baz2301,2,32314,5,62327,8,9233"""234)235file_path = tmp_path / "*.csv"236assert pl.read_csv(file_path, skip_rows=2).to_dict(as_series=False) == {237"foo": [1, 4, 7, 1, 4, 7],238"bar": [2, 5, 8, 2, 5, 8],239"baz": [3, 6, 9, 3, 6, 9],240}241242243def test_glob_n_rows(io_files_path: Path) -> None:244file_path = io_files_path / "foods*.csv"245df = pl.scan_csv(file_path, n_rows=40).collect()246247# 27 rows from foods1.csv and 13 from foods2.csv248assert df.shape == (40, 4)249250# take first and last rows251assert df[[0, 39]].to_dict(as_series=False) == {252"category": ["vegetables", "seafood"],253"calories": [45, 146],254"fats_g": [0.5, 6.0],255"sugars_g": [2, 2],256}257258259def test_scan_csv_schema_overwrite_not_projected_8483(foods_file_path: Path) -> None:260df = (261pl.scan_csv(262foods_file_path,263schema_overrides={"calories": pl.String, "sugars_g": pl.Int8},264)265.select(pl.len())266.collect()267)268expected = pl.DataFrame({"len": 27}, schema={"len": 
pl.UInt32})269assert_frame_equal(df, expected)270271272def test_csv_list_arg(io_files_path: Path) -> None:273first = io_files_path / "foods1.csv"274second = io_files_path / "foods2.csv"275276df = pl.scan_csv(source=[first, second]).collect()277assert df.shape == (54, 4)278assert df.row(-1) == ("seafood", 194, 12.0, 1)279assert df.row(0) == ("vegetables", 45, 0.5, 2)280281282# https://github.com/pola-rs/polars/issues/9887283def test_scan_csv_slice_offset_zero(io_files_path: Path) -> None:284lf = pl.scan_csv(io_files_path / "small.csv")285result = lf.slice(0)286assert result.collect().height == 4287288289@pytest.mark.write_disk290def test_scan_empty_csv_with_row_index(tmp_path: Path) -> None:291tmp_path.mkdir(exist_ok=True)292file_path = tmp_path / "small.parquet"293df = pl.DataFrame({"a": []})294df.write_csv(file_path)295296read = pl.scan_csv(file_path).with_row_index("idx")297assert read.collect().schema == OrderedDict([("idx", pl.UInt32), ("a", pl.String)])298299300@pytest.mark.write_disk301def test_csv_null_values_with_projection_15515() -> None:302data = """IndCode,SireCode,BirthDate,Flag303ID00316,.,19940315,304"""305306with tempfile.NamedTemporaryFile() as f:307f.write(data.encode())308f.seek(0)309310q = (311pl.scan_csv(f.name, null_values={"SireCode": "."})312.with_columns(pl.col("SireCode").alias("SireKey"))313.select("SireKey", "BirthDate")314)315316assert q.collect().to_dict(as_series=False) == {317"SireKey": [None],318"BirthDate": [19940315],319}320321322@pytest.mark.write_disk323def test_csv_respect_user_schema_ragged_lines_15254() -> None:324with tempfile.NamedTemporaryFile() as f:325f.write(326b"""327A,B,C3281,2,33294,5,6,7,83309,10,11331""".strip()332)333f.seek(0)334335df = pl.scan_csv(336f.name, schema=dict.fromkeys("ABCDE", pl.String), truncate_ragged_lines=True337).collect()338assert df.to_dict(as_series=False) == {339"A": ["1", "4", "9"],340"B": ["2", "5", "10"],341"C": ["3", "6", "11"],342"D": [None, "7", None],343"E": [None, "8", 
None],344}345346347@pytest.mark.parametrize("streaming", [True, False])348@pytest.mark.parametrize(349"dfs",350[351[pl.DataFrame({"a": [1, 2, 3]}), pl.DataFrame({"b": [4, 5, 6]})],352[353pl.DataFrame({"a": [1, 2, 3]}),354pl.DataFrame({"b": [4, 5, 6], "c": [7, 8, 9]}),355],356],357)358@pytest.mark.may_fail_auto_streaming # missing_columns parameter for CSV359def test_file_list_schema_mismatch(360tmp_path: Path, dfs: list[pl.DataFrame], streaming: bool361) -> None:362tmp_path.mkdir(exist_ok=True)363364paths = [f"{tmp_path}/{i}.csv" for i in range(len(dfs))]365366for df, path in zip(dfs, paths):367df.write_csv(path)368369lf = pl.scan_csv(paths)370with pytest.raises((ComputeError, pl.exceptions.ColumnNotFoundError)):371lf.collect(engine="streaming" if streaming else "in-memory")372373if streaming:374pytest.xfail(reason="missing_columns parameter for CSV")375376if len({df.width for df in dfs}) == 1:377expect = pl.concat(df.select(x=pl.first().cast(pl.Int8)) for df in dfs)378out = pl.scan_csv(paths, schema={"x": pl.Int8}).collect( # type: ignore[call-overload]379engine="streaming" if streaming else "in-memory" # type: ignore[redundant-expr]380)381382assert_frame_equal(out, expect)383384385@pytest.mark.may_fail_auto_streaming386@pytest.mark.parametrize("streaming", [True, False])387def test_file_list_schema_supertype(tmp_path: Path, streaming: bool) -> None:388tmp_path.mkdir(exist_ok=True)389390data_lst = [391"""\392a39313942395""",396"""\397a398b399c400""",401]402403paths = [f"{tmp_path}/{i}.csv" for i in range(len(data_lst))]404405for data, path in zip(data_lst, paths):406with Path(path).open("w") as f:407f.write(data)408409expect = pl.Series("a", ["1", "2", "b", "c"]).to_frame()410out = pl.scan_csv(paths).collect(engine="streaming" if streaming else "in-memory")411412assert_frame_equal(out, expect)413414415@pytest.mark.parametrize("streaming", [True, False])416def test_file_list_comment_skip_rows_16327(tmp_path: Path, streaming: bool) -> 
None:417tmp_path.mkdir(exist_ok=True)418419data_lst = [420"""\421# comment422a423b424c425""",426"""\427a428b429c430""",431]432433paths = [f"{tmp_path}/{i}.csv" for i in range(len(data_lst))]434435for data, path in zip(data_lst, paths):436with Path(path).open("w") as f:437f.write(data)438439expect = pl.Series("a", ["b", "c", "b", "c"]).to_frame()440out = pl.scan_csv(paths, comment_prefix="#").collect(441engine="streaming" if streaming else "in-memory"442)443444assert_frame_equal(out, expect)445446447@pytest.mark.xfail(reason="Bug: https://github.com/pola-rs/polars/issues/17634")448def test_scan_csv_with_column_names_nonexistent_file() -> None:449path_str = "my-nonexistent-data.csv"450path = Path(path_str)451assert not path.exists()452453# Just calling the scan function should not raise any errors454result = pl.scan_csv(path, with_column_names=lambda x: [c.upper() for c in x])455assert isinstance(result, pl.LazyFrame)456457# Upon collection, it should fail458with pytest.raises(FileNotFoundError):459result.collect()460461462def test_select_nonexistent_column() -> None:463csv = "a\n1"464f = io.StringIO(csv)465466with pytest.raises(pl.exceptions.ColumnNotFoundError):467pl.scan_csv(f).select("b").collect()468469470def test_scan_csv_provided_schema_with_extra_fields_22531() -> None:471data = b"""\472a,b,c473a,b,c474"""475476schema = {x: pl.String for x in ["a", "b", "c", "d", "e"]}477478assert_frame_equal(479pl.scan_csv(data, schema=schema).collect(),480pl.DataFrame(481{482"a": "a",483"b": "b",484"c": "c",485"d": None,486"e": None,487},488schema=schema,489),490)491492493def test_csv_negative_slice_comment_char_22996() -> None:494f = b"""\495a,b4961,1497"""498499q = pl.scan_csv(2 * [f], comment_prefix="#").tail(100)500assert_frame_equal(q.collect(), pl.DataFrame({"a": [1, 1], "b": [1, 1]}))501502503def test_csv_io_object_utf8_23629() -> None:504n_repeats = 10_000505for df in [506pl.DataFrame({"a": ["é,è"], "b": ["c,d"]}),507pl.DataFrame({"a": ["Ú;и"], "b": 
["c;d"]}),508pl.DataFrame({"a": ["a,b"], "b": ["c,d"]}),509pl.DataFrame({"a": ["é," * n_repeats + "è"], "b": ["c," * n_repeats + "d"]}),510]:511# bytes512f_bytes = io.BytesIO()513df.write_csv(f_bytes)514f_bytes.seek(0)515df_bytes = pl.read_csv(f_bytes)516assert_frame_equal(df, df_bytes)517518# str519f_str = io.StringIO()520df.write_csv(f_str)521f_str.seek(0)522df_str = pl.read_csv(f_str)523assert_frame_equal(df, df_str)524525526