# Path: blob/main/py-polars/tests/unit/io/test_lazy_count_star.py
# 6939 views
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from pathlib import Path

    from polars.lazyframe.frame import LazyFrame

import gzip
import re
from tempfile import NamedTemporaryFile

import pytest

import polars as pl
from polars.testing import assert_frame_equal


def _collect_with_project_logs(
    lf: LazyFrame, capfd: pytest.CaptureFixture[str]
) -> tuple[pl.DataFrame, set[str]]:
    """
    Collect `lf` and return the result with the `project: N` entries logged to stderr.

    Anything captured before the call is discarded, so the returned set only
    contains entries emitted while this particular `collect()` was running.
    """
    capfd.readouterr()  # resets stderr
    result = lf.collect()
    capture = capfd.readouterr().err
    return result, set(re.findall(r"project: \d+", capture))


def assert_fast_count(
    lf: LazyFrame,
    expected_count: int,
    *,
    expected_name: str = "len",
    capfd: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """
    Check a COUNT(*) query under default, disabled and enabled fast-count dispatch.

    Parameters
    ----------
    lf
        COUNT(*) query.
    expected_count
        Row count the query must return.
    expected_name
        Name of the single output column.
    capfd
        Capture fixture used to read the verbose logs from stderr.
    monkeypatch
        Fixture used to set `POLARS_VERBOSE` and
        `POLARS_FAST_FILE_COUNT_DISPATCH`.
    """
    monkeypatch.setenv("POLARS_VERBOSE", "1")

    result, project_logs = _collect_with_project_logs(lf, capfd)

    # Logs currently differ depending on file type / implementation dispatch
    if "FAST COUNT" in lf.explain():
        # * Should be no projections when fast count is enabled
        assert not project_logs
    else:
        # * Otherwise should have at least one `project: 0` (there is 1 per file).
        assert project_logs == {"project: 0"}

    assert result.schema == {expected_name: pl.get_index_type()}
    assert result.item() == expected_count

    # Test effect of the environment variable
    monkeypatch.setenv("POLARS_FAST_FILE_COUNT_DISPATCH", "0")

    _, project_logs = _collect_with_project_logs(lf, capfd)

    assert "FAST COUNT" not in lf.explain()
    assert project_logs == {"project: 0"}

    monkeypatch.setenv("POLARS_FAST_FILE_COUNT_DISPATCH", "1")

    _, project_logs = _collect_with_project_logs(lf, capfd)

    assert "FAST COUNT" in lf.explain()
    assert not project_logs


@pytest.mark.parametrize(
    ("path", "n_rows"), [("foods1.csv", 27), ("foods*.csv", 27 * 5)]
)
def test_count_csv(
    io_files_path: Path,
    path: str,
    n_rows: int,
    capfd: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Count rows of a single CSV file and of a glob of CSV files."""
    lf = pl.scan_csv(io_files_path / path).select(pl.len())

    assert_fast_count(lf, n_rows, capfd=capfd, monkeypatch=monkeypatch)


def test_count_csv_comment_char(
    capfd: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch
) -> None:
    """Count rows of in-memory CSV data containing an empty line and a comment line."""
    # Same bytes as the original triple-quoted literal: a leading newline, the
    # header, a data row, an empty line, a bare comment line and a last row.
    q = pl.scan_csv(
        b"\na,b\n1,2\n\n#\n3,4\n",
        comment_prefix="#",
    )

    assert_frame_equal(
        q.collect(), pl.DataFrame({"a": [1, None, 3], "b": [2, None, 4]})
    )

    q = q.select(pl.len())
    assert_fast_count(q, 3, capfd=capfd, monkeypatch=monkeypatch)


def test_count_csv_no_newline_on_last_22564() -> None:
    """Count rows of CSV data whose last line has no trailing newline."""
    data = b"a,b\n1,2\n3,4\n5,6"

    assert pl.scan_csv(data).collect().height == 3
    assert pl.scan_csv(data, comment_prefix="#").collect().height == 3

    assert pl.scan_csv(data).select(pl.len()).collect().item() == 3
    assert pl.scan_csv(data, comment_prefix="#").select(pl.len()).collect().item() == 3


@pytest.mark.write_disk
def test_commented_csv(
    capfd: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch
) -> None:
    """Count rows of an on-disk CSV file that ends with a comment line."""
    with NamedTemporaryFile() as csv_a:
        csv_a.write(b"A,B\nGr1,A\nGr1,B\n# comment line\n")
        csv_a.seek(0)

        lf = pl.scan_csv(csv_a.name, comment_prefix="#").select(pl.len())
        assert_fast_count(lf, 2, capfd=capfd, monkeypatch=monkeypatch)


@pytest.mark.parametrize(
    ("pattern", "n_rows"), [("small.parquet", 4), ("foods*.parquet", 54)]
)
def test_count_parquet(
    io_files_path: Path,
    pattern: str,
    n_rows: int,
    capfd: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Count rows of a single Parquet file and of a glob of Parquet files."""
    lf = pl.scan_parquet(io_files_path / pattern).select(pl.len())
    assert_fast_count(lf, n_rows, capfd=capfd, monkeypatch=monkeypatch)


@pytest.mark.parametrize(
    ("path", "n_rows"), [("foods1.ipc", 27), ("foods*.ipc", 27 * 2)]
)
def test_count_ipc(
    io_files_path: Path,
    path: str,
    n_rows: int,
    capfd: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Count rows of a single IPC file and of a glob of IPC files."""
    lf = pl.scan_ipc(io_files_path / path).select(pl.len())
    assert_fast_count(lf, n_rows, capfd=capfd, monkeypatch=monkeypatch)


@pytest.mark.parametrize(
    ("path", "n_rows"), [("foods1.ndjson", 27), ("foods*.ndjson", 27 * 2)]
)
def test_count_ndjson(
    io_files_path: Path,
    path: str,
    n_rows: int,
    capfd: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Count rows of a single NDJSON file and of a glob of NDJSON files."""
    lf = pl.scan_ndjson(io_files_path / path).select(pl.len())
    assert_fast_count(lf, n_rows, capfd=capfd, monkeypatch=monkeypatch)


def test_count_compressed_csv_18057(
    io_files_path: Path,
    capfd: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Read and count the rows of the gzipped CSV fixture."""
    csv_file = io_files_path / "gzipped.csv.gz"

    expected = pl.DataFrame(
        {"a": [1, 2, 3], "b": ["a", "b", "c"], "c": [1.0, 2.0, 3.0]}
    )
    lf = pl.scan_csv(csv_file, truncate_ragged_lines=True)
    out = lf.collect()
    assert_frame_equal(out, expected)
    # This also tests:
    # #18070 "CSV count_rows does not skip empty lines at file start"
    # as the file has an empty line at the beginning.

    q = lf.select(pl.len())
    assert_fast_count(q, 3, capfd=capfd, monkeypatch=monkeypatch)


def test_count_compressed_ndjson(
    tmp_path: Path, capfd: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch
) -> None:
    """Count rows of an NDJSON file written through `gzip.open`."""
    tmp_path.mkdir(exist_ok=True)
    path = tmp_path / "data.jsonl.gz"
    df = pl.DataFrame({"x": range(5)})

    with gzip.open(path, "wb") as f:
        df.write_ndjson(f)  # type: ignore[call-overload]

    lf = pl.scan_ndjson(path).select(pl.len())
    assert_fast_count(lf, 5, capfd=capfd, monkeypatch=monkeypatch)


def test_count_projection_pd(
    capfd: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch
) -> None:
    """Count rows of a CSV scan followed by `with_row_index` and `select(pl.all())`."""
    df = pl.DataFrame({"a": range(3), "b": range(3)})

    q = (
        pl.scan_csv(df.write_csv().encode())
        .with_row_index()
        .select(pl.all())
        .select(pl.len())
    )

    # Manual assert, this is not converted to FAST COUNT but we will have
    # 0-width projections.

    monkeypatch.setenv("POLARS_VERBOSE", "1")
    result, project_logs = _collect_with_project_logs(q, capfd)

    assert project_logs == {"project: 0"}
    assert result.item() == 3


def test_csv_scan_skip_lines_len_22889(
    capfd: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch
) -> None:
    """Count rows of CSV scans that use `skip_lines`."""
    bb = b"col\n1\n2\n3"
    lf = pl.scan_csv(bb, skip_lines=2).select(pl.len())
    assert_fast_count(lf, 1, capfd=capfd, monkeypatch=monkeypatch)

    ## trigger multi-threading code path
    bb_10k = b"1\n2\n3\n4\n5\n6\n7\n8\n9\n0\n" * 1000
    lf = pl.scan_csv(bb_10k, skip_lines=1000, has_header=False).select(pl.len())
    assert_fast_count(lf, 9000, capfd=capfd, monkeypatch=monkeypatch)

    # for comparison
    out = pl.scan_csv(bb, skip_lines=2).collect().select(pl.len())
    # Use the index dtype rather than a hard-coded UInt32 so this agrees with
    # the schema check in `assert_fast_count`.
    expected = pl.DataFrame({"len": [1]}, schema={"len": pl.get_index_type()})
    assert_frame_equal(expected, out)