Path: blob/main/py-polars/tests/unit/streaming/test_streaming_group_by.py
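
"""Unit tests for group_by on the streaming engine.

All tests in this module run in the "streaming" xdist group; the out-of-core
(OOC) tests additionally write spill files to disk.
"""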

from __future__ import annotations

from datetime import date
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest

import polars as pl
from polars.exceptions import DuplicateError
from polars.testing import assert_frame_equal
from tests.unit.conftest import INTEGER_DTYPES

if TYPE_CHECKING:
    from pathlib import Path

pytestmark = pytest.mark.xdist_group("streaming")


@pytest.mark.slow
def test_streaming_group_by_sorted_fast_path_nulls_10273() -> None:
    df = pl.Series(
        name="x",
        values=(
            *(i for i in range(4) for _ in range(100)),
            *(None for _ in range(100)),
        ),
    ).to_frame()

    assert (
        df.set_sorted("x")
        .lazy()
        .group_by("x")
        .agg(pl.len())
        .collect(engine="streaming")
        .sort("x")
    ).to_dict(as_series=False) == {
        "x": [None, 0, 1, 2, 3],
        "len": [100, 100, 100, 100, 100],
    }


def test_streaming_group_by_types() -> None:
    df = pl.DataFrame(
        {
            "person_id": [1, 1],
            "year": [1995, 1995],
            "person_name": ["bob", "foo"],
            "bool": [True, False],
            "date": [date(2022, 1, 1), date(2022, 1, 1)],
        }
    )

    for by in ["person_id", "year", "date", ["person_id", "year"]]:
        out = (
            (
                df.lazy()
                .group_by(by)
                .agg(
                    [
                        pl.col("person_name").first().alias("str_first"),
                        pl.col("person_name").last().alias("str_last"),
                        pl.col("person_name").mean().alias("str_mean"),
                        pl.col("person_name").sum().alias("str_sum"),
                        pl.col("bool").first().alias("bool_first"),
                        pl.col("bool").last().alias("bool_last"),
                        pl.col("bool").mean().alias("bool_mean"),
                        pl.col("bool").sum().alias("bool_sum"),
                        # pl.col("date").sum().alias("date_sum"),
                        # Date streaming mean/median has been temporarily disabled
                        # pl.col("date").mean().alias("date_mean"),
                        pl.col("date").first().alias("date_first"),
                        pl.col("date").last().alias("date_last"),
                        pl.col("date").min().alias("date_min"),
                        pl.col("date").max().alias("date_max"),
                    ]
                )
            )
            .select(pl.all().exclude(by))
            .collect(engine="streaming")
        )
        assert out.schema == {
            "str_first": pl.String,
            "str_last": pl.String,
            "str_mean": pl.String,
            "str_sum": pl.String,
            "bool_first": pl.Boolean,
            "bool_last": pl.Boolean,
            "bool_mean": pl.Float64,
            "bool_sum": pl.UInt32,
            # "date_sum": pl.Date,
            # "date_mean": pl.Date,
            "date_first": pl.Date,
            "date_last": pl.Date,
            "date_min": pl.Date,
            "date_max": pl.Date,
        }

        assert out.to_dict(as_series=False) == {
            "str_first": ["bob"],
            "str_last": ["foo"],
            "str_mean": [None],
            "str_sum": [None],
            "bool_first": [True],
            "bool_last": [False],
            "bool_mean": [0.5],
            "bool_sum": [1],
            # "date_sum": [None],
            # Date streaming mean/median has been temporarily disabled
            # "date_mean": [date(2022, 1, 1)],
            "date_first": [date(2022, 1, 1)],
            "date_last": [date(2022, 1, 1)],
            "date_min": [date(2022, 1, 1)],
            "date_max": [date(2022, 1, 1)],
        }

    # Two aggregations aliased to the same output name must raise.
    with pytest.raises(DuplicateError):
        (
            df.lazy()
            .group_by("person_id")
            .agg(
                [
                    pl.col("person_name").first().alias("str_first"),
                    pl.col("person_name").last().alias("str_last"),
                    pl.col("person_name").mean().alias("str_mean"),
                    pl.col("person_name").sum().alias("str_sum"),
                    pl.col("bool").first().alias("bool_first"),
                    pl.col("bool").last().alias("bool_first"),
                ]
            )
            .select(pl.all().exclude("person_id"))
            .collect(engine="streaming")
        )


def test_streaming_group_by_min_max() -> None:
    df = pl.DataFrame(
        {
            "person_id": [1, 2, 3, 4, 5, 6],
            "year": [1995, 1995, 1995, 2, 2, 2],
        }
    )
    out = (
        df.lazy()
        .group_by("year")
        .agg([pl.min("person_id").alias("min"), pl.max("person_id").alias("max")])
        .collect()
        .sort("year")
    )
    assert out["min"].to_list() == [4, 1]
    assert out["max"].to_list() == [6, 3]


def test_streaming_non_streaming_gb() -> None:
    n = 100
    df = pl.DataFrame({"a": np.random.randint(0, 20, n)})
    q = df.lazy().group_by("a").agg(pl.len()).sort("a")
    assert_frame_equal(q.collect(engine="streaming"), q.collect())

    q = df.lazy().with_columns(pl.col("a").cast(pl.String))
    q = q.group_by("a").agg(pl.len()).sort("a")
    assert_frame_equal(q.collect(engine="streaming"), q.collect())
    q = df.lazy().with_columns(pl.col("a").alias("b"))
    q = q.group_by(["a", "b"]).agg(pl.len(), pl.col("a").sum().alias("sum_a")).sort("a")
    assert_frame_equal(q.collect(engine="streaming"), q.collect())


def test_streaming_group_by_sorted_fast_path() -> None:
    a = np.random.randint(0, 20, 80)
    df = pl.DataFrame(
        {
            # test on int8 as that also tests proper conversions
            "a": pl.Series(np.sort(a), dtype=pl.Int8)
        }
    ).with_row_index()

    df_sorted = df.with_columns(pl.col("a").set_sorted())

    for streaming in [True, False]:
        results = []
        for df_ in [df, df_sorted]:
            out = (
                df_.lazy()
                .group_by("a")
                .agg(
                    [
                        pl.first("a").alias("first"),
                        pl.last("a").alias("last"),
                        pl.sum("a").alias("sum"),
                        pl.mean("a").alias("mean"),
                        pl.count("a").alias("count"),
                        pl.min("a").alias("min"),
                        pl.max("a").alias("max"),
                    ]
                )
                .sort("a")
                .collect(engine="streaming" if streaming else "in-memory")
            )
            results.append(out)

        # The sorted fast path and the generic path must agree.
        assert_frame_equal(results[0], results[1])


@pytest.fixture(scope="module")
def random_integers() -> pl.Series:
    np.random.seed(1)
    return pl.Series("a", np.random.randint(0, 10, 100), dtype=pl.Int64)
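

# The three OOC (out-of-core) tests below share one setup: POLARS_TEMP_DIR points
# spill files at the pytest tmp_path, and POLARS_FORCE_OOC forces the group_by to
# spill to disk even though the data easily fits in memory.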
@pytest.mark.write_disk
def test_streaming_group_by_ooc_q1(
    random_integers: pl.Series,
    tmp_path: Path,
    monkeypatch: Any,
) -> None:
    tmp_path.mkdir(exist_ok=True)
    monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path))
    monkeypatch.setenv("POLARS_FORCE_OOC", "1")

    lf = random_integers.to_frame().lazy()
    result = (
        lf.group_by("a")
        .agg(pl.first("a").alias("a_first"), pl.last("a").alias("a_last"))
        .sort("a")
        .collect(engine="streaming")
    )

    expected = pl.DataFrame(
        {
            "a": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            "a_first": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            "a_last": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        }
    )
    assert_frame_equal(result, expected)


@pytest.mark.write_disk
def test_streaming_group_by_ooc_q2(
    random_integers: pl.Series,
    tmp_path: Path,
    monkeypatch: Any,
) -> None:
    tmp_path.mkdir(exist_ok=True)
    monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path))
    monkeypatch.setenv("POLARS_FORCE_OOC", "1")

    lf = random_integers.cast(str).to_frame().lazy()
    result = (
        lf.group_by("a")
        .agg(pl.first("a").alias("a_first"), pl.last("a").alias("a_last"))
        .sort("a")
        .collect(engine="streaming")
    )

    expected = pl.DataFrame(
        {
            "a": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
            "a_first": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
            "a_last": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
        }
    )
    assert_frame_equal(result, expected)


@pytest.mark.write_disk
def test_streaming_group_by_ooc_q3(
    random_integers: pl.Series,
    tmp_path: Path,
    monkeypatch: Any,
) -> None:
    tmp_path.mkdir(exist_ok=True)
    monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path))
    monkeypatch.setenv("POLARS_FORCE_OOC", "1")

    lf = pl.LazyFrame({"a": random_integers, "b": random_integers})
    result = (
        lf.group_by("a", "b")
        .agg(pl.first("a").alias("a_first"), pl.last("a").alias("a_last"))
        .sort("a")
        .collect(engine="streaming")
    )

    expected = pl.DataFrame(
        {
            "a": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            "b": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            "a_first": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            "a_last": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        }
    )
    assert_frame_equal(result, expected)


def test_streaming_group_by_struct_key() -> None:
    df = pl.DataFrame(
        {"A": [1, 2, 3, 2], "B": ["google", "ms", "apple", "ms"], "C": [2, 3, 4, 3]}
    )
    df1 = df.lazy().with_columns(pl.struct(["A", "C"]).alias("tuples"))
    assert df1.group_by("tuples").agg(pl.len(), pl.col("B").first()).sort("B").collect(
        engine="streaming"
    ).to_dict(as_series=False) == {
        "tuples": [{"A": 3, "C": 4}, {"A": 1, "C": 2}, {"A": 2, "C": 3}],
        "len": [1, 1, 2],
        "B": ["apple", "google", "ms"],
    }
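

# Presumably the point of issue 8570: casting a 0/1 column to every integer dtype
# (and Boolean) and summing per group must reproduce the original total for each
# width, i.e. the streaming sum may not overflow or drop groups for narrow dtypes.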
@pytest.mark.slow
def test_streaming_group_by_all_numeric_types_stability_8570() -> None:
    m = 1000
    n = 1000

    rng = np.random.default_rng(seed=0)
    dfa = pl.DataFrame({"x": pl.arange(start=0, end=n, eager=True)})
    dfb = pl.DataFrame(
        {
            "y": rng.integers(low=0, high=10, size=m),
            "z": rng.integers(low=0, high=2, size=m),
        }
    )
    dfc = dfa.join(dfb, how="cross")

    for keys in [["x", "y"], "z"]:
        for dtype in [*INTEGER_DTYPES, pl.Boolean]:
            # The alias checks that the output schema is handled correctly.
            dfd = (
                dfc.lazy()
                .with_columns(pl.col("z").cast(dtype))
                .group_by(keys)
                .agg(pl.col("z").sum().alias("z_sum"))
                .collect(engine="streaming")
            )
            assert dfd["z_sum"].sum() == dfc["z"].sum()


def test_streaming_group_by_categorical_aggregate() -> None:
    out = (
        pl.LazyFrame(
            {
                "a": pl.Series(
                    ["a", "a", "b", "b", "c", "c", None, None], dtype=pl.Categorical
                ),
                "b": pl.Series(
                    pl.date_range(
                        date(2023, 4, 28),
                        date(2023, 5, 5),
                        eager=True,
                    ).to_list(),
                    dtype=pl.Date,
                ),
            }
        )
        .group_by(["a", "b"])
        .agg([pl.col("a").first().alias("sum")])
        .collect(engine="streaming")
    )

    assert out.sort("b").to_dict(as_series=False) == {
        "a": ["a", "a", "b", "b", "c", "c", None, None],
        "b": [
            date(2023, 4, 28),
            date(2023, 4, 29),
            date(2023, 4, 30),
            date(2023, 5, 1),
            date(2023, 5, 2),
            date(2023, 5, 3),
            date(2023, 5, 4),
            date(2023, 5, 5),
        ],
        "sum": ["a", "a", "b", "b", "c", "c", None, None],
    }


def test_streaming_group_by_list_9758() -> None:
    payload = {"a": [[1, 2]]}
    assert (
        pl.LazyFrame(payload)
        .group_by("a")
        .first()
        .collect(engine="streaming")
        .to_dict(as_series=False)
        == payload
    )


def test_group_by_min_max_string_type() -> None:
    table = pl.from_dict({"a": [1, 1, 2, 2, 2], "b": ["a", "b", "c", "d", None]})

    expected = {"a": [1, 2], "min": ["a", "c"], "max": ["b", "d"]}

    for streaming in [True, False]:
        assert (
            table.lazy()
            .group_by("a")
            .agg([pl.min("b").alias("min"), pl.max("b").alias("max")])
            .collect(engine="streaming" if streaming else "in-memory")
            .sort("a")
            .to_dict(as_series=False)
            == expected
        )


@pytest.mark.parametrize("literal", [True, "foo", 1])
def test_streaming_group_by_literal(literal: Any) -> None:
    df = pl.LazyFrame({"a": range(20)})

    assert df.group_by(pl.lit(literal)).agg(
        [
            pl.col("a").count().alias("a_count"),
            pl.col("a").sum().alias("a_sum"),
        ]
    ).collect(engine="streaming").to_dict(as_series=False) == {
        "literal": [literal],
        "a_count": [20],
        "a_sum": [190],
    }


@pytest.mark.parametrize("streaming", [True, False])
def test_group_by_multiple_keys_one_literal(streaming: bool) -> None:
    df = pl.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]})

    expected = {"a": [1, 2], "literal": [1, 1], "b": [5, 6]}
    assert (
        df.lazy()
        .group_by("a", pl.lit(1))
        .agg(pl.col("b").max())
        .sort(["a", "b"])
        .collect(engine="streaming" if streaming else "in-memory")
        .to_dict(as_series=False)
        == expected
    )


def test_streaming_group_null_count() -> None:
    df = pl.DataFrame({"g": [1] * 6, "a": ["yes", None] * 3}).lazy()
    assert df.group_by("g").agg(pl.col("a").count()).collect(
        engine="streaming"
    ).to_dict(as_series=False) == {"g": [1], "a": [3]}


def test_streaming_group_by_binary_15116() -> None:
    assert (
        pl.LazyFrame(
            {
                "str": [
                    "A",
                    "A",
                    "BB",
                    "BB",
                    "CCCC",
                    "CCCC",
                    "DDDDDDDD",
                    "DDDDDDDD",
                    "EEEEEEEEEEEEEEEE",
                    "A",
                ]
            }
        )
        .select([pl.col("str").cast(pl.Binary)])
        .group_by(["str"])
        .agg([pl.len().alias("count")])
    ).sort("str").collect(engine="streaming").to_dict(as_series=False) == {
        "str": [b"A", b"BB", b"CCCC", b"DDDDDDDD", b"EEEEEEEEEEEEEEEE"],
        "count": [3, 2, 2, 2, 1],
    }
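

# `partition_limit` is a fixture, presumably provided by the streaming conftest,
# giving the row count at which the streaming group_by starts partitioning; the
# regression tests below are sized relative to it.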
def test_streaming_group_by_convert_15380(partition_limit: int) -> None:
    assert (
        pl.DataFrame({"a": [1] * partition_limit}).group_by(b="a").len()["len"].item()
        == partition_limit
    )


@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize("n_rows_limit_offset", [-1, +3])
def test_streaming_group_by_boolean_mean_15610(
    n_rows_limit_offset: int, streaming: bool, partition_limit: int
) -> None:
    n_rows = partition_limit + n_rows_limit_offset

    # Also test non-streaming because it is sometimes dispatched to a streaming agg.
    expect = pl.DataFrame({"a": [False, True], "c": [0.0, 0.5]})

    n_repeats = n_rows // 3
    assert n_repeats > 0

    out = (
        pl.select(
            a=pl.repeat([True, False, True], n_repeats).explode(),
            b=pl.repeat([True, False, False], n_repeats).explode(),
        )
        .lazy()
        .group_by("a")
        .agg(c=pl.mean("b"))
        .sort("a")
        .collect(engine="streaming" if streaming else "in-memory")
    )

    assert_frame_equal(out, expect)


def test_streaming_group_by_all_null_21593() -> None:
    df = pl.DataFrame(
        {
            "col_1": ["A", "B", "C", "D"],
            "col_2": ["test", None, None, None],
        }
    )

    out = df.lazy().group_by(pl.all()).min().collect(engine="streaming")
    assert_frame_equal(df, out, check_row_order=False)