Path: blob/main/py-polars/tests/unit/io/test_lazy_count_star.py
from __future__ import annotations

import subprocess
import sys
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from pathlib import Path

    from polars.lazyframe.frame import LazyFrame
    from tests.conftest import PlMonkeyPatch

import gzip
import re
from tempfile import NamedTemporaryFile

import pytest

import polars as pl
from polars.testing import assert_frame_equal


# Parameters
# * lf: COUNT(*) query
def assert_fast_count(
    lf: LazyFrame,
    expected_count: int,
    *,
    expected_name: str = "len",
    capfd: pytest.CaptureFixture[str],
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    capfd.readouterr()  # reset captured stderr

    with plmonkeypatch.context() as cx:
        cx.setenv("POLARS_VERBOSE", "1")
        result = lf.collect()

    capture = capfd.readouterr().err
    project_logs = set(re.findall(r"project: \d+", capture))

    # Logs currently differ depending on file type / implementation dispatch.
    if "FAST COUNT" in lf.explain():
        # * There should be no projections when the fast count is enabled.
        assert not project_logs
    else:
        # * Otherwise there should be at least one `project: 0` (there is 1 per file).
        assert project_logs == {"project: 0"}

    assert result.schema == {expected_name: pl.get_index_type()}
    assert result.item() == expected_count

    # We disable the fast-count optimization to check that the normal scan
    # logic counts as expected.
    plmonkeypatch.setenv("POLARS_NO_FAST_FILE_COUNT", "1")

    capfd.readouterr()

    with plmonkeypatch.context() as cx:
        cx.setenv("POLARS_VERBOSE", "1")
        assert lf.collect().item() == expected_count

    capture = capfd.readouterr().err
    project_logs = set(re.findall(r"project: \d+", capture))

    assert "FAST COUNT" not in lf.explain()
    assert project_logs == {"project: 0"}

    # Restore the default behavior that allows the fast-count optimization.
    plmonkeypatch.setenv("POLARS_NO_FAST_FILE_COUNT", "0")

    plan = lf.explain()
    if "Csv" not in plan:
        assert "FAST COUNT" not in plan
        return

    # CSV is the only format that uses a custom fast-count kernel, so we want
    # to make sure that the normal scan logic has the same count behavior.
    assert "FAST COUNT" in plan

    capfd.readouterr()

    with plmonkeypatch.context() as cx:
        cx.setenv("POLARS_VERBOSE", "1")
        assert lf.collect().item() == expected_count

    capture = capfd.readouterr().err
    project_logs = set(re.findall(r"project: \d+", capture))

    assert not project_logs
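

# A minimal illustration of what assert_fast_count exercises (the file name is
# hypothetical): a bare length query over a CSV scan is rewritten to a
# fast-count node in the optimized plan,
#
#     lf = pl.scan_csv("data.csv").select(pl.len())
#     assert "FAST COUNT" in lf.explain()
#
# whereas setting POLARS_NO_FAST_FILE_COUNT=1 forces the regular scan path,
# which reads zero columns and therefore logs "project: 0" under
# POLARS_VERBOSE.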


@pytest.mark.parametrize(
    ("path", "n_rows"), [("foods1.csv", 27), ("foods*.csv", 27 * 5)]
)
def test_count_csv(
    io_files_path: Path,
    path: str,
    n_rows: int,
    capfd: pytest.CaptureFixture[str],
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    lf = pl.scan_csv(io_files_path / path).select(pl.len())

    assert_fast_count(lf, n_rows, capfd=capfd, plmonkeypatch=plmonkeypatch)


def test_count_csv_comment_char(
    capfd: pytest.CaptureFixture[str], plmonkeypatch: PlMonkeyPatch
) -> None:
    q = pl.scan_csv(
        b"""
a,b
1,2

#
3,4
""",
        comment_prefix="#",
    )

    assert_frame_equal(
        q.collect(), pl.DataFrame({"a": [1, None, 3], "b": [2, None, 4]})
    )

    q = q.select(pl.len())
    assert_fast_count(q, 3, capfd=capfd, plmonkeypatch=plmonkeypatch)


def test_count_csv_no_newline_on_last_22564() -> None:
    data = b"""\
a,b
1,2
3,4
5,6"""

    assert pl.scan_csv(data).collect().height == 3
    assert pl.scan_csv(data, comment_prefix="#").collect().height == 3

    assert pl.scan_csv(data).select(pl.len()).collect().item() == 3
    assert pl.scan_csv(data, comment_prefix="#").select(pl.len()).collect().item() == 3


@pytest.mark.write_disk
def test_commented_csv(
    capfd: pytest.CaptureFixture[str], plmonkeypatch: PlMonkeyPatch
) -> None:
    with NamedTemporaryFile() as csv_a:
        csv_a.write(b"A,B\nGr1,A\nGr1,B\n# comment line\n")
        csv_a.seek(0)

        lf = pl.scan_csv(csv_a.name, comment_prefix="#").select(pl.len())
        assert_fast_count(lf, 2, capfd=capfd, plmonkeypatch=plmonkeypatch)

    lf = pl.scan_csv(
        b"AAA",
        has_header=False,
        comment_prefix="#",
    ).select(pl.len())
    assert_fast_count(lf, 1, capfd=capfd, plmonkeypatch=plmonkeypatch)

    lf = pl.scan_csv(
        b"AAA\nBBB",
        has_header=False,
        comment_prefix="#",
    ).select(pl.len())
    assert_fast_count(lf, 2, capfd=capfd, plmonkeypatch=plmonkeypatch)

    lf = pl.scan_csv(
        b"AAA\n#comment\nBBB\n#comment",
        has_header=False,
        comment_prefix="#",
    ).select(pl.len())
    assert_fast_count(lf, 2, capfd=capfd, plmonkeypatch=plmonkeypatch)

    lf = pl.scan_csv(
        b"AAA\n#comment\nBBB\n#comment\nCCC\n#comment",
        has_header=False,
        comment_prefix="#",
    ).select(pl.len())
    assert_fast_count(lf, 3, capfd=capfd, plmonkeypatch=plmonkeypatch)

    lf = pl.scan_csv(
        b"AAA\n#comment\nBBB\n#comment\nCCC\n#comment\n",
        has_header=False,
        comment_prefix="#",
    ).select(pl.len())
    assert_fast_count(lf, 3, capfd=capfd, plmonkeypatch=plmonkeypatch)


@pytest.mark.parametrize(
    ("pattern", "n_rows"), [("small.parquet", 4), ("foods*.parquet", 54)]
)
def test_count_parquet(
    io_files_path: Path,
    pattern: str,
    n_rows: int,
    capfd: pytest.CaptureFixture[str],
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    lf = pl.scan_parquet(io_files_path / pattern).select(pl.len())
    assert_fast_count(lf, n_rows, capfd=capfd, plmonkeypatch=plmonkeypatch)


@pytest.mark.parametrize(
    ("path", "n_rows"), [("foods1.ipc", 27), ("foods*.ipc", 27 * 2)]
)
def test_count_ipc(
    io_files_path: Path,
    path: str,
    n_rows: int,
    capfd: pytest.CaptureFixture[str],
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    lf = pl.scan_ipc(io_files_path / path).select(pl.len())
    assert_fast_count(lf, n_rows, capfd=capfd, plmonkeypatch=plmonkeypatch)
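

# Note on the glob cases above (hypothetical paths): a multi-file scan such as
#
#     pl.scan_csv("foods*.csv").select(pl.len())
#
# should report the summed count across all matched files (27 * 5 above), and
# with the fast count disabled the verbose logs contain a "project: 0" entry
# per file, which assert_fast_count checks for.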


@pytest.mark.parametrize(
    ("path", "n_rows"), [("foods1.ndjson", 27), ("foods*.ndjson", 27 * 2)]
)
def test_count_ndjson(
    io_files_path: Path,
    path: str,
    n_rows: int,
    capfd: pytest.CaptureFixture[str],
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    lf = pl.scan_ndjson(io_files_path / path).select(pl.len())
    assert_fast_count(lf, n_rows, capfd=capfd, plmonkeypatch=plmonkeypatch)


def test_count_compressed_csv_18057(
    io_files_path: Path,
    capfd: pytest.CaptureFixture[str],
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    csv_file = io_files_path / "gzipped.csv.gz"

    expected = pl.DataFrame(
        {"a": [1, 2, 3], "b": ["a", "b", "c"], "c": [1.0, 2.0, 3.0]}
    )
    lf = pl.scan_csv(csv_file, truncate_ragged_lines=True)
    out = lf.collect()
    assert_frame_equal(out, expected)
    # This also tests:
    # #18070 "CSV count_rows does not skip empty lines at file start"
    # as the file has an empty line at the beginning.

    q = lf.select(pl.len())
    assert_fast_count(q, 3, capfd=capfd, plmonkeypatch=plmonkeypatch)


@pytest.mark.write_disk
def test_count_compressed_ndjson(
    tmp_path: Path, capfd: pytest.CaptureFixture[str], plmonkeypatch: PlMonkeyPatch
) -> None:
    tmp_path.mkdir(exist_ok=True)
    path = tmp_path / "data.jsonl.gz"
    df = pl.DataFrame({"x": range(5)})

    with gzip.open(path, "wb") as f:
        df.write_ndjson(f)  # type: ignore[call-overload]

    lf = pl.scan_ndjson(path).select(pl.len())
    assert_fast_count(lf, 5, capfd=capfd, plmonkeypatch=plmonkeypatch)


def test_count_projection_pd(
    capfd: pytest.CaptureFixture[str], plmonkeypatch: PlMonkeyPatch
) -> None:
    df = pl.DataFrame({"a": range(3), "b": range(3)})

    q = (
        pl.scan_csv(df.write_csv().encode())
        .with_row_index()
        .select(pl.all())
        .select(pl.len())
    )

    # Manual assert: this query is not converted to FAST COUNT, but we will
    # have 0-width projections.
    plmonkeypatch.setenv("POLARS_VERBOSE", "1")
    capfd.readouterr()
    result = q.collect()
    capture = capfd.readouterr().err
    project_logs = set(re.findall(r"project: \d+", capture))

    assert project_logs == {"project: 0"}
    assert result.item() == 3


def test_csv_scan_skip_lines_len_22889(
    capfd: pytest.CaptureFixture[str], plmonkeypatch: PlMonkeyPatch
) -> None:
    bb = b"col\n1\n2\n3"
    lf = pl.scan_csv(bb, skip_lines=2).select(pl.len())
    assert_fast_count(lf, 1, capfd=capfd, plmonkeypatch=plmonkeypatch)

    # Trigger the multi-threaded code path.
    bb_10k = b"1\n2\n3\n4\n5\n6\n7\n8\n9\n0\n" * 1000
    lf = pl.scan_csv(bb_10k, skip_lines=1000, has_header=False).select(pl.len())
    assert_fast_count(lf, 9000, capfd=capfd, plmonkeypatch=plmonkeypatch)

    # For comparison.
    out = pl.scan_csv(bb, skip_lines=2).collect().select(pl.len())
    expected = pl.DataFrame({"len": [1]}, schema={"len": pl.get_index_type()})
    assert_frame_equal(expected, out)
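

# Worked example for the skip_lines test above: with bb = b"col\n1\n2\n3" and
# skip_lines=2, the lines "col" and "1" are skipped, "2" is then consumed as
# the header, and "3" is the only data row, giving a count of 1. The larger
# input has 10_000 header-less lines; skipping 1_000 leaves 9_000 rows.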
"b").collect().height',319"""\320pl.collect_all(321[322pl.scan_parquet(parquet_file_path).select(pl.len()),323pl.scan_ipc(ipc_file_path).select(pl.len()),324pl.LazyFrame(height=n_rows).select(pl.len()),325]326)[0].item()""",327],328)329def test_streaming_fast_count_disables_morsel_split(330tmp_path: Path, exec_str: str331) -> None:332n_rows = (1 << 32) - 2333parquet_file_path = tmp_path / "data.parquet"334ipc_file_path = tmp_path / "data.ipc"335336script_args = [str(n_rows), str(parquet_file_path), str(ipc_file_path), exec_str]337338# We spawn 2 processes - the first process sets a huge ideal morsel size to339# generate the data quickly. The 2nd process sets the ideal morsel size to 1,340# making it so that if morsel splitting is performed it would exceed the341# timeout of 5 seconds.342343assert (344subprocess.check_output(345[346sys.executable,347"-c",348"""\349import os350import sys351352os.environ["POLARS_IDEAL_MORSEL_SIZE"] = str(1_000_000_000)353354import polars as pl355356pl.Config.set_engine_affinity("streaming")357358(359_,360n_rows,361parquet_file_path,362ipc_file_path,363_,364) = sys.argv365366n_rows = int(n_rows)367368pl.LazyFrame(height=n_rows).sink_parquet(parquet_file_path, row_group_size=1_000_000_000)369pl.LazyFrame(height=n_rows).sink_ipc(ipc_file_path, record_batch_size=1_000_000_000)370371print("OK", end="")372""",373*script_args,374],375timeout=5,376)377== b"OK"378)379380assert (381subprocess.check_output(382[383sys.executable,384"-c",385"""\386import os387import sys388389os.environ["POLARS_IDEAL_MORSEL_SIZE"] = "1"390391import polars as pl392393pl.Config.set_engine_affinity("streaming")394395(396_,397n_rows,398parquet_file_path,399ipc_file_path,400exec_str,401) = sys.argv402403n_rows = int(n_rows)404405s = pl.Series([{}], dtype=pl.Struct({})).new_from_index(0, n_rows)406assert eval(exec_str) == n_rows407408print("OK", end="")409""",410*script_args,411],412timeout=5,413)414== b"OK"415)416417418