Path: blob/main/py-polars/tests/unit/io/test_lazy_csv.py
from __future__ import annotations

import io
import tempfile
from collections import OrderedDict
from pathlib import Path
from typing import IO

import numpy as np
import pytest

import polars as pl
from polars.exceptions import ComputeError, ShapeError
from polars.testing import assert_frame_equal


@pytest.fixture
def foods_file_path(io_files_path: Path) -> Path:
    return io_files_path / "foods1.csv"


def test_scan_csv(io_files_path: Path) -> None:
    df = pl.scan_csv(io_files_path / "small.csv")
    assert df.collect().shape == (4, 3)


def test_scan_csv_no_cse_deadlock(io_files_path: Path) -> None:
    dfs = [pl.scan_csv(io_files_path / "small.csv")] * (pl.thread_pool_size() + 1)
    pl.concat(dfs, parallel=True).collect(
        optimizations=pl.QueryOptFlags(comm_subplan_elim=False)
    )


def test_scan_empty_csv(io_files_path: Path) -> None:
    with pytest.raises(Exception) as excinfo:
        pl.scan_csv(io_files_path / "empty.csv").collect()
    assert "empty CSV" in str(excinfo.value)

    lf = pl.scan_csv(io_files_path / "empty.csv", raise_if_empty=False)
    assert_frame_equal(lf, pl.LazyFrame())


@pytest.mark.write_disk
def test_invalid_utf8(tmp_path: Path) -> None:
    tmp_path.mkdir(exist_ok=True)

    np.random.seed(1)
    bts = bytes(np.random.randint(0, 255, 200))

    file_path = tmp_path / "nonutf8.csv"
    file_path.write_bytes(bts)

    a = pl.read_csv(file_path, has_header=False, encoding="utf8-lossy")
    b = pl.scan_csv(file_path, has_header=False, encoding="utf8-lossy").collect()

    assert_frame_equal(a, b)


def test_row_index(foods_file_path: Path) -> None:
    df = pl.read_csv(foods_file_path, row_index_name="row_index")
    assert df["row_index"].to_list() == list(range(27))

    df = (
        pl.scan_csv(foods_file_path, row_index_name="row_index")
        .filter(pl.col("category") == pl.lit("vegetables"))
        .collect()
    )

    assert df["row_index"].to_list() == [0, 6, 11, 13, 14, 20, 25]

    df = (
        pl.scan_csv(foods_file_path, row_index_name="row_index")
        .with_row_index("foo", 10)
        .filter(pl.col("category") == pl.lit("vegetables"))
        .collect()
    )

    assert df["foo"].to_list() == [10, 16, 21, 23, 24, 30, 35]


@pytest.mark.parametrize("file_name", ["foods1.csv", "foods*.csv"])
@pytest.mark.may_fail_auto_streaming  # missing_columns parameter for CSV
def test_scan_csv_schema_overwrite_and_dtypes_overwrite(
    io_files_path: Path, file_name: str
) -> None:
    file_path = io_files_path / file_name
    q = pl.scan_csv(
        file_path,
        schema_overrides={"calories_foo": pl.String, "fats_g_foo": pl.Float32},
        with_column_names=lambda names: [f"{a}_foo" for a in names],
    )

    assert q.collect_schema().dtypes() == [pl.String, pl.String, pl.Float32, pl.Int64]

    df = q.collect()

    assert df.dtypes == [pl.String, pl.String, pl.Float32, pl.Int64]
    assert df.columns == [
        "category_foo",
        "calories_foo",
        "fats_g_foo",
        "sugars_g_foo",
    ]


@pytest.mark.parametrize("file_name", ["foods1.csv", "foods*.csv"])
@pytest.mark.parametrize("dtype", [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16])
@pytest.mark.may_fail_auto_streaming  # missing_columns parameter for CSV
def test_scan_csv_schema_overwrite_and_small_dtypes_overwrite(
    io_files_path: Path, file_name: str, dtype: pl.DataType
) -> None:
    file_path = io_files_path / file_name
    df = pl.scan_csv(
        file_path,
        schema_overrides={"calories_foo": pl.String, "sugars_g_foo": dtype},
        with_column_names=lambda names: [f"{a}_foo" for a in names],
    ).collect()
    assert df.dtypes == [pl.String, pl.String, pl.Float64, dtype]
    assert df.columns == [
        "category_foo",
        "calories_foo",
        "fats_g_foo",
        "sugars_g_foo",
    ]


@pytest.mark.parametrize("file_name", ["foods1.csv", "foods*.csv"])
@pytest.mark.may_fail_auto_streaming  # missing_columns parameter for CSV
def test_scan_csv_schema_new_columns_dtypes(
    io_files_path: Path, file_name: str
) -> None:
    file_path = io_files_path / file_name

    for dtype in [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16]:
        # assign 'new_columns', providing partial dtype overrides
        df1 = pl.scan_csv(
            file_path,
            schema_overrides={"calories": pl.String, "sugars": dtype},
            new_columns=["category", "calories", "fats", "sugars"],
        ).collect()
        assert df1.dtypes == [pl.String, pl.String, pl.Float64, dtype]
        assert df1.columns == ["category", "calories", "fats", "sugars"]

        # assign 'new_columns' with 'dtypes' list
        df2 = pl.scan_csv(
            file_path,
            schema_overrides=[pl.String, pl.String, pl.Float64, dtype],
            new_columns=["category", "calories", "fats", "sugars"],
        ).collect()
        assert df1.rows() == df2.rows()

    # rename existing columns, then lazy-select disjoint cols
    lf = pl.scan_csv(
        file_path,
        new_columns=["colw", "colx", "coly", "colz"],
    )
    schema = lf.collect_schema()
    assert schema.dtypes() == [pl.String, pl.Int64, pl.Float64, pl.Int64]
    assert schema.names() == ["colw", "colx", "coly", "colz"]
    assert (
        lf.select("colz", "colx").collect().rows()
        == df1.select("sugars", pl.col("calories").cast(pl.Int64)).rows()
    )

    # partially rename columns / overwrite dtypes
    df4 = pl.scan_csv(
        file_path,
        schema_overrides=[pl.String, pl.String],
        new_columns=["category", "calories"],
    ).collect()
    assert df4.dtypes == [pl.String, pl.String, pl.Float64, pl.Int64]
    assert df4.columns == ["category", "calories", "fats_g", "sugars_g"]

    # cannot have len(new_columns) > len(actual columns)
    with pytest.raises(ShapeError):
        pl.scan_csv(
            file_path,
            schema_overrides=[pl.String, pl.String],
            new_columns=["category", "calories", "c3", "c4", "c5"],
        ).collect()

    # cannot set both 'new_columns' and 'with_column_names'
    with pytest.raises(ValueError, match=r"mutually.exclusive"):
        pl.scan_csv(
            file_path,
            schema_overrides=[pl.String, pl.String],
            new_columns=["category", "calories", "fats", "sugars"],
            with_column_names=lambda cols: [col.capitalize() for col in cols],
        ).collect()


def test_lazy_n_rows(foods_file_path: Path) -> None:
    df = (
        pl.scan_csv(foods_file_path, n_rows=4, row_index_name="idx")
        .filter(pl.col("idx") > 2)
        .collect()
    )
    assert df.to_dict(as_series=False) == {
        "idx": [3],
        "category": ["fruit"],
        "calories": [60],
        "fats_g": [0.0],
        "sugars_g": [11],
    }


def test_lazy_row_index_no_push_down(foods_file_path: Path) -> None:
    q = (
        pl.scan_csv(foods_file_path)
        .with_row_index()
        .filter(pl.col("index") > 13)
        .filter(pl.col("category") == pl.lit("vegetables"))
    )

    plan = q.explain()

    assert "FILTER" not in plan

    assert_frame_equal(
        q,
        pl.LazyFrame(
            [
                pl.Series("index", [14, 20, 25], dtype=pl.get_index_type()),
                pl.Series(
                    "category",
                    ["vegetables", "vegetables", "vegetables"],
                    dtype=pl.String,
                ),
                pl.Series("calories", [25, 25, 30], dtype=pl.Int64),
                pl.Series("fats_g", [0.0, 0.0, 0.0], dtype=pl.Float64),
                pl.Series("sugars_g", [4, 3, 5], dtype=pl.Int64),
            ]
        ),
    )


@pytest.mark.write_disk
def test_glob_skip_rows(tmp_path: Path) -> None:
    tmp_path.mkdir(exist_ok=True)

    for i in range(2):
f"test_{i}.csv"242file_path.write_text(243f"""244metadata goes here245file number {i}246foo,bar,baz2471,2,32484,5,62497,8,9250"""251)252file_path = tmp_path / "*.csv"253assert pl.read_csv(file_path, skip_rows=2).to_dict(as_series=False) == {254"foo": [1, 4, 7, 1, 4, 7],255"bar": [2, 5, 8, 2, 5, 8],256"baz": [3, 6, 9, 3, 6, 9],257}258259260def test_glob_n_rows(io_files_path: Path) -> None:261file_path = io_files_path / "foods*.csv"262df = pl.scan_csv(file_path, n_rows=40).collect()263264# 27 rows from foods1.csv and 13 from foods2.csv265assert df.shape == (40, 4)266267# take first and last rows268assert df[[0, 39]].to_dict(as_series=False) == {269"category": ["vegetables", "seafood"],270"calories": [45, 146],271"fats_g": [0.5, 6.0],272"sugars_g": [2, 2],273}274275276def test_scan_csv_schema_overwrite_not_projected_8483(foods_file_path: Path) -> None:277df = (278pl.scan_csv(279foods_file_path,280schema_overrides={"calories": pl.String, "sugars_g": pl.Int8},281)282.select(pl.len())283.collect()284)285expected = pl.DataFrame({"len": 27}, schema={"len": pl.get_index_type()})286assert_frame_equal(df, expected)287288289def test_csv_list_arg(io_files_path: Path) -> None:290first = io_files_path / "foods1.csv"291second = io_files_path / "foods2.csv"292293df = pl.scan_csv(source=[first, second]).collect()294assert df.shape == (54, 4)295assert df.row(-1) == ("seafood", 194, 12.0, 1)296assert df.row(0) == ("vegetables", 45, 0.5, 2)297298299# https://github.com/pola-rs/polars/issues/9887300def test_scan_csv_slice_offset_zero(io_files_path: Path) -> None:301lf = pl.scan_csv(io_files_path / "small.csv")302result = lf.slice(0)303assert result.collect().height == 4304305306@pytest.mark.write_disk307def test_scan_empty_csv_with_row_index(tmp_path: Path) -> None:308tmp_path.mkdir(exist_ok=True)309file_path = tmp_path / "small.csv"310df = pl.DataFrame({"a": []})311df.write_csv(file_path)312313read = pl.scan_csv(file_path).with_row_index("idx")314assert read.collect().schema == OrderedDict([("idx", pl.UInt32), ("a", pl.String)])315316317@pytest.mark.write_disk318def test_csv_null_values_with_projection_15515() -> None:319data = """IndCode,SireCode,BirthDate,Flag320ID00316,.,19940315,321"""322323with tempfile.NamedTemporaryFile() as f:324f.write(data.encode())325f.seek(0)326327q = (328pl.scan_csv(f.name, null_values={"SireCode": "."})329.with_columns(pl.col("SireCode").alias("SireKey"))330.select("SireKey", "BirthDate")331)332333assert q.collect().to_dict(as_series=False) == {334"SireKey": [None],335"BirthDate": [19940315],336}337338339@pytest.mark.write_disk340def test_csv_respect_user_schema_ragged_lines_15254() -> None:341with tempfile.NamedTemporaryFile() as f:342f.write(343b"""344A,B,C3451,2,33464,5,6,7,83479,10,11348""".strip()349)350f.seek(0)351352df = pl.scan_csv(353f.name, schema=dict.fromkeys("ABCDE", pl.String), truncate_ragged_lines=True354).collect()355assert df.to_dict(as_series=False) == {356"A": ["1", "4", "9"],357"B": ["2", "5", "10"],358"C": ["3", "6", "11"],359"D": [None, "7", None],360"E": [None, "8", None],361}362363364@pytest.mark.parametrize("streaming", [True, False])365@pytest.mark.parametrize(366"dfs",367[368[pl.DataFrame({"a": [1, 2, 3]}), pl.DataFrame({"b": [4, 5, 6]})],369[370pl.DataFrame({"a": [1, 2, 3]}),371pl.DataFrame({"b": [4, 5, 6], "c": [7, 8, 9]}),372],373],374)375@pytest.mark.may_fail_auto_streaming # missing_columns parameter for CSV376def test_file_list_schema_mismatch(377tmp_path: Path, dfs: list[pl.DataFrame], streaming: bool378) -> 
) -> None:
    tmp_path.mkdir(exist_ok=True)

    paths = [f"{tmp_path}/{i}.csv" for i in range(len(dfs))]

    for df, path in zip(dfs, paths, strict=True):
        df.write_csv(path)

    lf = pl.scan_csv(paths)
    with pytest.raises((ComputeError, pl.exceptions.ColumnNotFoundError)):
        lf.collect(engine="streaming" if streaming else "in-memory")

    if streaming:
        pytest.xfail(reason="missing_columns parameter for CSV")

    if len({df.width for df in dfs}) == 1:
        expect = pl.concat(df.select(x=pl.first().cast(pl.Int8)) for df in dfs)
        out = pl.scan_csv(paths, schema={"x": pl.Int8}).collect(  # type: ignore[call-overload]
            engine="streaming" if streaming else "in-memory"  # type: ignore[redundant-expr]
        )

        assert_frame_equal(out, expect)


@pytest.mark.may_fail_auto_streaming
@pytest.mark.parametrize("streaming", [True, False])
def test_file_list_schema_supertype(tmp_path: Path, streaming: bool) -> None:
    tmp_path.mkdir(exist_ok=True)

    data_lst = [
        """\
a
1
2
""",
        """\
a
b
c
""",
    ]

    paths = [f"{tmp_path}/{i}.csv" for i in range(len(data_lst))]

    for data, path in zip(data_lst, paths, strict=True):
        with Path(path).open("w") as f:
            f.write(data)

    expect = pl.Series("a", ["1", "2", "b", "c"]).to_frame()
    out = pl.scan_csv(paths).collect(engine="streaming" if streaming else "in-memory")

    assert_frame_equal(out, expect)


@pytest.mark.parametrize("streaming", [True, False])
def test_file_list_comment_skip_rows_16327(tmp_path: Path, streaming: bool) -> None:
    tmp_path.mkdir(exist_ok=True)

    data_lst = [
        """\
# comment
a
b
c
""",
        """\
a
b
c
""",
    ]

    paths = [f"{tmp_path}/{i}.csv" for i in range(len(data_lst))]

    for data, path in zip(data_lst, paths, strict=True):
        with Path(path).open("w") as f:
            f.write(data)

    expect = pl.Series("a", ["b", "c", "b", "c"]).to_frame()
    out = pl.scan_csv(paths, comment_prefix="#").collect(
        engine="streaming" if streaming else "in-memory"
    )

    assert_frame_equal(out, expect)


@pytest.mark.xfail(reason="Bug: https://github.com/pola-rs/polars/issues/17634")
def test_scan_csv_with_column_names_nonexistent_file() -> None:
    path_str = "my-nonexistent-data.csv"
    path = Path(path_str)
    assert not path.exists()

    # Just calling the scan function should not raise any errors
    result = pl.scan_csv(path, with_column_names=lambda x: [c.upper() for c in x])
    assert isinstance(result, pl.LazyFrame)

    # Upon collection, it should fail
    with pytest.raises(FileNotFoundError):
        result.collect()


def test_select_nonexistent_column() -> None:
    csv = "a\n1"
    f = io.StringIO(csv)

    with pytest.raises(pl.exceptions.ColumnNotFoundError):
        pl.scan_csv(f).select("b").collect()


def test_scan_csv_provided_schema_with_extra_fields_22531() -> None:
    data = b"""\
a,b,c
a,b,c
"""

    schema = dict.fromkeys(["a", "b", "c", "d", "e"], pl.String)

    assert_frame_equal(
        pl.scan_csv(data, schema=schema).collect(),
        pl.DataFrame(
            {
                "a": "a",
                "b": "b",
                "c": "c",
                "d": None,
                "e": None,
            },
            schema=schema,
        ),
    )


def test_csv_negative_slice_comment_char_22996() -> None:
    f = b"""\
a,b
1,1
"""

    q = pl.scan_csv(2 * [f], comment_prefix="#").tail(100)
    assert_frame_equal(q.collect(), pl.DataFrame({"a": [1, 1], "b": [1, 1]}))


def test_csv_io_object_utf8_23629() -> None:
    n_repeats = 10_000
    for df in [
        pl.DataFrame({"a": ["é,è"], "b": ["c,d"]}),
["c;d"]}),525pl.DataFrame({"a": ["a,b"], "b": ["c,d"]}),526pl.DataFrame({"a": ["é," * n_repeats + "è"], "b": ["c," * n_repeats + "d"]}),527]:528# bytes529f_bytes = io.BytesIO()530df.write_csv(f_bytes)531f_bytes.seek(0)532df_bytes = pl.read_csv(f_bytes)533assert_frame_equal(df, df_bytes)534535# str536f_str = io.StringIO()537df.write_csv(f_str)538f_str.seek(0)539df_str = pl.read_csv(f_str)540assert_frame_equal(df, df_str)541542543def test_scan_csv_multiple_files_skip_rows_overflow_26127() -> None:544files: list[IO[bytes]] = [545io.BytesIO(b"foo,bar,baz\n1,2,3\n4,5,6") for _ in range(2)546]547assert_frame_equal(548pl.scan_csv(549files,550n_rows=4,551skip_rows=2,552).collect(),553pl.DataFrame(schema={"4": pl.String, "5": pl.String, "6": pl.String}),554)555556557